Image Classification¶

Motivation¶

  1. Implementing and evaluating a multi-layer perceptron (MLP) and a convolutional neural network (CNN) on a classification problem
  2. Building, evaluating, and fine-tuning a CNN on an image dataset from development to testing
  3. Tackling overfitting using strategies such as data augmentation and dropout
  4. Fine-tuning a model
  5. Comparing the performance of a new model with an off-the-shelf model (AlexNet)
  6. Gaining a deeper understanding of model performance using visualisations from Grad-CAM.

Setup and resources¶

Having a GPU will speed up the training process. See the provided document on Minerva about setting up a working environment for various ways to access a GPU. We highly recommend you use platforms such as Colab.

Please implement the coursework using Python and PyTorch, and refer to the notebooks and exercises provided.

This coursework will use a subset of images from Tiny ImageNet, which is itself a subset of the ImageNet dataset. Our subset of Tiny ImageNet contains 30 different categories; we will refer to it as TinyImageNet30. The training set has 450 resized images (64x64 pixels) per category (13,500 images in total). You can download the training and test sets from a direct link or the Kaggle challenge website:

Direct access to the data is possible by clicking here; please use your university email to access it

Access data through Kaggle webpage

In [21]:
import math

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.hub import load_state_dict_from_url

from PIL import Image
import matplotlib.pyplot as plt
In [2]:
import pathlib
import os
import glob
from torch.utils.data import Dataset,DataLoader

import cv2 as cv
import time
from torchcam.methods import SmoothGradCAMpp
import torchvision
from sklearn import metrics
import pandas as pd
from torch import optim
from torch.optim import SGD, Adam
from collections import OrderedDict
from scipy import interp
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
from itertools import cycle
import torchvision.models as models
In [103]:
# always check your version
print(torch.__version__)
1.13.0

One challenge of building a deep learning model is to choose an architecture that can learn the features in the dataset without being unnecessarily complex. The first part of the coursework involves building a CNN and training it on TinyImageNet30.

Overview:¶

1. Function implementation

  • 1.1 PyTorch Dataset and DataLoader classes
  • 1.2 PyTorch Model class for a simple MLP model
  • 1.3 PyTorch Model class for a simple CNN model

2. Model training

  • 2.1 Train on TinyImageNet30 dataset
  • 2.2 Generate confusion matrices and ROC curves
  • 2.3 Strategies for tackling overfitting
    • 2.3.1 Data augmentation
    • 2.3.2 Dropout
    • 2.3.3 Hyperparameter tuning (e.g. changing learning rate)

3. Model Fine-tuning on CIFAR10 dataset

  • 3.1 Fine-tune model (initialise your model with pretrained weights from (2))
  • 3.2 Fine-tune model with frozen base convolution layers
  • 3.3 Compare retraining the complete model with pretrained weights against training with frozen layers. Comment on what you observe.

4. Model testing

  • 4.1 Test your final model from (2) on the test set (include the code to do this)
  • 4.2 Upload result to Kaggle

5. Model comparison

  • 5.1 Load pretrained AlexNet and finetune on TinyImageNet30 until model convergence
  • 5.2 Compare the results of your CNN model with pretrained AlexNet on the same validation set. Provide performance values (loss graph, confusion matrix, top-1 accuracy, execution time)

6. Interpretation of results

  • 6.1 Use Grad-CAM on your model and on AlexNet
  • 6.2 Visualise and compare the results from your model and from AlexNet
  • 6.3 Comment on:
    • why the network's predictions were correct or incorrect
    • what you could do to improve the results further

1 Function implementations¶

1.1 Dataset class¶

Write a PyTorch Dataset class (an example here for reference) which loads the TinyImageNet30 dataset, and create DataLoaders for the training and validation sets.

In [39]:
# TO COMPLETE
classes = {}
with open("./comp5625M_data_assessment_1/class.txt", "r") as f:  # open class file to get all classes and labels
    data = f.read().splitlines()
    for item in data:
        item_class = item.strip("\t").split("\t")
        label, category  = item_class[0], item_class[1]
        classes[category] = label
classes
Out[39]:
{'baboon': '0',
 'banana': '1',
 'bee': '2',
 'bison': '3',
 'butterfly': '4',
 'candle': '5',
 'cardigan': '6',
 'chihuahua': '7',
 'elephant': '8',
 'espresso': '9',
 'fly': '10',
 'goldfish': '11',
 'goose': '12',
 'grasshopper': '13',
 'hourglass': '14',
 'icecream': '15',
 'ipod': '16',
 'jellyfish': '17',
 'koala': '18',
 'ladybug': '19',
 'lion': '20',
 'mushroom': '21',
 'penguin': '22',
 'pig': '23',
 'pizza': '24',
 'pretzel': '25',
 'redpanda': '26',
 'refrigerator': '27',
 'sombrero': '28',
 'umbrella': '29'}
In [38]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
os.environ['CUDA_VISIBLE_DEVICES'] ='0'
Using device: cuda
In [4]:
train_transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.485, 0.456, 0.406], 
                  std = [0.229, 0.224, 0.225]),
        ])
class MyDataset(Dataset):
    def __init__(self, data_type, transform=train_transformer):
        '''
        data_type : one of ["train_set", "test_set"]
        '''
        # root directory of the dataset
        root_path = "./comp5625M_data_assessment_1/"
        self.data_type = data_type
        # directory containing the per-category image folders
        data_root = pathlib.Path(root_path + self.data_type + "/" + self.data_type)
        self.transform = transform

        if self.data_type == "train_set":
            # collect every image path under each category sub-directory
            all_image_paths = list(data_root.glob("*/*"))
            # look up each label in the global `classes` dict;
            # path.parent.name is the parent directory's name, i.e. the category
            self.all_image_labels = [int(classes[path.parent.name]) for path in all_image_paths]
            self.all_image_paths = [str(path) for path in all_image_paths]
        else:
            all_image_paths = list(data_root.glob("*/"))
            self.all_image_paths = [str(path) for path in all_image_paths]
            # keep the path as the "label" so it can be written to the test CSV later
            self.all_image_labels = [str(path) for path in all_image_paths]

    def __getitem__(self, index):
        img = cv.imread(self.all_image_paths[index])  # note: OpenCV loads images in BGR order
        img = self.transform(img)
        label = self.all_image_labels[index]
        return img, label

    def __len__(self):
        return len(self.all_image_paths)
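One subtlety worth noting: `cv.imread` returns images in BGR channel order, while the normalisation statistics above are the standard RGB ImageNet values. A minimal sketch of the conversion (illustrative only; equivalent to `cv.cvtColor(img, cv.COLOR_BGR2RGB)`):

```python
import numpy as np

# OpenCV's cv.imread returns an HxWx3 array in BGR order; reversing the
# last axis gives RGB.
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255          # blue channel comes first in BGR
rgb = bgr[..., ::-1]       # channel order is now R, G, B
assert (rgb[..., 2] == 255).all()  # blue is the last channel after conversion
```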

1.2 Define an MLP model class¶

Create a new model class using a combination of:

  • Input Units
  • Hidden Units
  • Output Units
  • Activation functions
  • Loss function
  • Optimiser
In [5]:
# TO COMPLETE
# define a MLP Model class
class MLP_Class(nn.Module):
    def __init__(self):
        super(MLP_Class,self).__init__()
        self.layer = nn.Sequential(
            OrderedDict(
                [   
                    ("flatten", nn.Flatten()),
                    ("hidden_1_layer", nn.Linear(3*64*64,1024)),
                    ('relu1', nn.ReLU()),
                    ("hidden_2_layer", nn.Linear(1024,512)),
                    ('relu2', nn.ReLU()),
                    ("hidden_3_layer", nn.Linear(512, 30)),
                ]
            ))
    def forward(self,x):
        x = self.layer(x)
        return x
MLP_model = MLP_Class()
MLP_model = MLP_model.to(device)
print(MLP_model)
MLP_Class(
  (layer): Sequential(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (hidden_1_layer): Linear(in_features=12288, out_features=1024, bias=True)
    (relu1): ReLU()
    (hidden_2_layer): Linear(in_features=1024, out_features=512, bias=True)
    (relu2): ReLU()
    (hidden_3_layer): Linear(in_features=512, out_features=30, bias=True)
  )
)

1.3 Define a CNN model class¶

Create a new model class using a combination of:

  • Convolution layers
  • Activation functions (e.g. ReLU)
  • Maxpooling layers
  • Fully connected layers
  • Loss function
  • Optimiser
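The size of the first fully connected layer follows from the input resolution: each 2x2 max-pool halves the spatial size, so a 64x64 input becomes 8x8 after three pools, giving 64*8*8 = 4096 flattened features. A quick sanity check of that arithmetic:

```python
def spatial_size_after_pools(size, n_pools, pool=2):
    """Spatial size after n_pools non-overlapping pool x pool max-pools."""
    for _ in range(n_pools):
        size //= pool
    return size

# 64x64 input, three 2x2 max-pools: 64 -> 32 -> 16 -> 8
side = spatial_size_after_pools(64, 3)
print(side, 64 * side * side)  # 8 4096
```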
In [6]:
class CNN_Class(nn.Module):
    def __init__(self):
        super(CNN_Class,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flc1 = nn.Linear(64*8*8,1024)
        self.flc2 = nn.Linear(1024,30)
    def forward(self,x):
        x = self.maxpool1(nn.functional.relu(self.conv1(x)))
        x = self.maxpool2(nn.functional.relu(self.conv2(x)))
        x = self.maxpool3(nn.functional.relu(self.conv3(x)))
        x = x.view(-1,64*8*8)
        x = nn.functional.relu(self.flc1(x))
        x=self.flc2(x)
        return x
CNN_model = CNN_Class()
CNN_model = CNN_model.to(device)
print(CNN_model)
CNN_Class(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flc1): Linear(in_features=4096, out_features=1024, bias=True)
  (flc2): Linear(in_features=1024, out_features=30, bias=True)
)

2 Model training¶

2.1 Train both MLP and CNN models - show loss and accuracy graphs side by side¶

Train your model on the TinyImageNet30 dataset. Split the data into training and validation sets to determine when to stop training. Use a seed of 0 for reproducibility and test_ratio=0.2 (i.e. 20% of the data for validation).

Display the graph of training and validation loss over epochs and accuracy over epochs to show how you determined the optimal number of training epochs. Top-k accuracy implementation is provided for you below.

Please leave the graphs clearly displayed, and plot the training and validation curves on the same axes.

In [7]:
# split train dataset to train_set and validate_set based on test_ratio=0.2
train_set = MyDataset("train_set")
length=len(train_set)
train_size,validate_size=int(0.8*length),int(0.2*length)
train_set,validate_set=torch.utils.data.random_split(train_set,[train_size,validate_size],generator=torch.Generator().manual_seed(0))
print(len(train_set),len(validate_set))

train_loader = DataLoader(
    train_set,
    batch_size = 64,
    shuffle = True)
validate_loader = DataLoader(
    validate_set,
    batch_size = 64,
    shuffle = True)
10800 2700
In [9]:
def train(train_set, model, criterion, optimizer):
    model.train()
    n = 0
    train_running_loss = 0.0
    train_running_accuracy = 0.0
    for data in train_set:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_running_loss += loss.item() 
        train_running_accuracy += topk_accuracy(output = outputs, target = labels, topk=(1,))[0].cpu().float()
        n += 1
    return train_running_loss / n, (train_running_accuracy / n).cpu().numpy()

def validate(val_set, model, criterion, optimizer):
    model.eval()
    n = 0
    validate_running_loss = 0.0
    validate_running_accuracy = 0.0
    with torch.no_grad():
        for data in val_set:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            validate_running_loss += loss.item() 
            validate_running_accuracy += topk_accuracy(output = outputs, target = labels, topk=(1,))[0]
            n += 1
    return validate_running_loss / n, (validate_running_accuracy / n).cpu().numpy()
In [10]:
# Define top-*k* accuracy 
def topk_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
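As a sanity check on the helper above, here is a plain-NumPy version of top-k accuracy on a toy batch (an illustrative sketch, not needed for training):

```python
import numpy as np

def topk_accuracy_np(scores, targets, k):
    """Percentage of samples whose true label is among the k highest scores."""
    topk = np.argsort(scores, axis=1)[:, ::-1][:, :k]  # indices of k best scores per row
    hits = sum(int(t in row) for t, row in zip(targets, topk))
    return 100.0 * hits / len(targets)

scores = np.array([[0.1, 0.5, 0.4],
                   [0.7, 0.2, 0.1]])
targets = np.array([2, 0])
print(topk_accuracy_np(scores, targets, 1))  # 50.0  (only the second sample is a top-1 hit)
print(topk_accuracy_np(scores, targets, 2))  # 100.0 (class 2 is in the first sample's top-2)
```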
In [24]:
#TO COMPLETE --> Running your MLP model class

nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(MLP_model.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
train_loss, validate_loss, train_accuracy, validate_accuracy = [], [], [], []

for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, MLP_model, criterion, optimizer)
    train_loss.append(train_running_loss)
    train_accuracy.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, MLP_model, criterion, optimizer)
    validate_loss.append(validate_running_loss)
    validate_accuracy.append(validate_running_accuracy)
    scheduler.step()
    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(MLP_model.state_dict(), './MLP_model.pt')
In [27]:
# Your graph

x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets')
axs[0].plot(x_axis,train_loss,'r--',label='MLP_train_loss')
axs[0].plot(x_axis,validate_loss,'g--',label='MLP_validate_loss')
axs[1].plot(x_axis,train_accuracy,'b--',label='MLP_train_accuracy')
axs[1].plot(x_axis,validate_accuracy,'y--',label='MLP_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
In [28]:
#TO COMPLETE --> Running your CNN model class

nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
CNN_train_loss, CNN_validate_loss, CNN_train_accuracy, CNN_validate_accuracy = [], [], [], []

for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, CNN_model, criterion, optimizer)
    CNN_train_loss.append(train_running_loss)
    CNN_train_accuracy.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model, criterion, optimizer)
    CNN_validate_loss.append(validate_running_loss)
    CNN_validate_accuracy.append(validate_running_accuracy)
    scheduler.step()
    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(CNN_model.state_dict(), './CNN_model.pt')
In [29]:
# Your graph

x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets')
axs[0].plot(x_axis,CNN_train_loss,'r--',label='CNN_train_loss')
axs[0].plot(x_axis,CNN_validate_loss,'g--',label='CNN_validate_loss')
axs[1].plot(x_axis,CNN_train_accuracy,'b--',label='CNN_train_accuracy')
axs[1].plot(x_axis,CNN_validate_accuracy,'y--',label='CNN_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()

Comment on your models and results. Include the number of parameters in each model, and explain why a CNN is preferred over an MLP for image classification.
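One way to answer the parameter-count question is to compute the totals directly from the layer shapes defined in 1.2 and 1.3 (a sketch; in PyTorch you can equivalently use `sum(p.numel() for p in model.parameters())`):

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out            # weights + biases

def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out    # weights + biases

# MLP from 1.2: 12288 -> 1024 -> 512 -> 30
mlp_total = (linear_params(3 * 64 * 64, 1024)
             + linear_params(1024, 512)
             + linear_params(512, 30))

# CNN from 1.3: three 3x3 convs, then 4096 -> 1024 -> 30
cnn_total = (conv_params(3, 16, 3) + conv_params(16, 32, 3) + conv_params(32, 64, 3)
             + linear_params(64 * 8 * 8, 1024) + linear_params(1024, 30))

print(mlp_total)  # 13124126 -- dominated by the first 12288 -> 1024 layer
print(cnn_total)  # 4249662  -- weight sharing in convolutions keeps this much smaller
```

The CNN reaches higher accuracy with roughly a third of the parameters because convolution and pooling build in translation structure that an MLP would have to learn from scratch.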

2.2 Generating confusion matrix and ROC curves¶

  • Use your CNN architecture with best accuracy to generate two confusion matrices, one for the training set and another for the validation set. Remember to use the whole validation and training sets, and to include all your relevant code. Display the confusion matrices in a meaningful way that clearly indicates what percentage of the data is represented in each position.
  • Display ROC curve for 5 top classes with area under the curve
In [36]:
# Your code here!
def accuracy(cnfm):
    return cnfm.trace()/cnfm.sum((0,1))

def recalls(cnfm):
    return np.diag(cnfm)/cnfm.sum(1)

def precisions(cnfm):
    return np.diag(cnfm)/cnfm.sum(0)    

num_class = len(classes)
CNN_model.load_state_dict(torch.load('./CNN_model.pt'))
nclasses = len(classes)
cnfm = np.zeros((nclasses,nclasses),dtype=int)
score_list = []     # save predicted score
label_list = [] 
with torch.no_grad():
    for data in validate_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = CNN_model(images)
        _, predicted = torch.max(outputs, 1)    
        score_tmp = outputs
        for i in range(labels.size(0)):
            cnfm[labels[i].item(),predicted[i].item()] += 1
        score_list.extend(score_tmp.detach().cpu().numpy())
        label_list.extend(labels.cpu().numpy())      
print("Confusion matrix")
print(cnfm)


# show confusion matrix as a grey-level image
plt.imshow(cnfm, cmap='gray')

# show per-class recall and precision
print(f"Accuracy: {accuracy(cnfm) :.1%}")
r = recalls(cnfm)
p = precisions(cnfm)
for i in range(nclasses):
    print(f"Class {list(classes.keys())[i]} : Precision {p[i] :.1%}  Recall {r[i] :.1%}")
Confusion matrix
[[30  0  1  4  0  1  3  3  7  1  1  0  3  0  3  1  1  1  7  2  5  1  1 11
   0  0  3  0  0  0]
 [ 1 29  8  1  0  4  0  0  0  1  1  3  0  0  3  0  2  0  1  2  0  1  3  0
   5  2  0  0  3  1]
 [ 0  1 42  2  8  0  0  1  0  1 10  1  0  7  1  0  1  0  0  9  0  1  0  1
   0  1  1  0  0  3]
 [ 1  1  0 54  0  0  2  1  9  1  0  0  2  1  0  0  2  0  2  0  0  1  2  9
   0  0  3  0  1  0]
 [ 0  0  6  0 79  2  0  1  0  0  2  1  0  4  0  0  0  0  1  1  0  1  0  0
   0  1  0  0  2  3]
 [ 0  7  5  0  0 31  1  7  0  1  1  3  0  0  8  0  3  3  0  4  0  1  2  1
   0  1  2  0  4  4]
 [ 0  2  4  1  1  1 40  2  2  0  0  2  0  4  3  0  4  0  4  0  1  1  5  1
   0  0  2  2  3  7]
 [ 2  0  2  1  0  3  0 32  1  2  0  1  3  0  2  4  5  0  4  3  6  1  0  5
   0  2  5  3  2  2]
 [ 2  0  0 14  0  0  2  0 44  1  0  0  1  1  0  0  3  0  4  1  2  1  1  6
   0  0  2  1  2  1]
 [ 0  2  2  0  0  2  0  3  0 43  1  4  0  0  7  1  7  0  0  2  0  1  1  0
   3  3  1  1  3  0]
 [ 1  0 14  1  4  0  1  1  0  0 45  0  1 11  2  1  0  0  0 10  0  1  0  0
   1  0  0  0  2  1]
 [ 0  3  1  0  0  2  0  1  2  2  0 54  0  3  0  0  0  3  0  6  3  4  1  1
   0  0  2  0  2  2]
 [ 1  0  2  1  1  1  2  3  3  1  2  0 32  1  2  1  4  3  1  1  4  2  4  9
   0  0  2  0  1  4]
 [ 0  3 15  2  3  1  1  1  2  0  7  1  2 46  0  0  1  0  2  4  1  0  1  0
   0  0  0  1  0  0]
 [ 0  3  2  0  0 11  2  4  0  0  2  0  2  2 41  0  5  1  2  1  0  0  4  0
   1  0  0  2  1  4]
 [ 0  3  0  0  1  8  1  4  1  8  3  4  1  1  4 12  3  1  0  5  2  3  0  4
   4  8  0  1  9  3]
 [ 1  2  0  2  0  1  1  3  0  2  1  0  3  1  4  2 34  0  0  0  1  1  2  1
   0  1  0  3  4  2]
 [ 0  1  0  0  0  2  0  0  0  0  1  2  0  1  2  0  0 68  0  3  0  0  0  0
   0  1  0  0  0  2]
 [ 3  0  1  3  1  0  1  2  5  0  1  0  2  0  0  0  0  0 62  0  2  2  0  3
   0  1  2  0  0  2]
 [ 1  6 17  2  1  0  0  3  0  1  4  1  1 12  1  1  2  0  0 40  1  1  1  0
   0  0  1  1  0  0]
 [ 4  1  3  0  0  1  1  6  5  0  0  0  1  2  2  0  0  0  4  1 38  5  0 10
   0  4  1  2  2  2]
 [ 0  1  3  2  4  4  0  2  3  2  1  3  1  3  0  0  1  1  1  1  1 35  0  2
   1  0  3  0  5  2]
 [ 0  1  2  3  1  1  3  4  5  0  0  0  7  0  4  1  0  3  0  0  0  0 35  1
   0  0  0  3  3  2]
 [ 5  0  0  6  0  0  0  4  8  0  1  1  6  2  1  0  0  1  3  0  2  5  2 33
   4  1  2  3  1  2]
 [ 0  3  0  0  1  5  0  1  0  0  0  1  0  3  0  2  0  0  0  0  0  2  0  1
  55  6  0  1  4  2]
 [ 0  3  2  2  1  5  1  3  0  6  1  2  0  0  2  2  1  0  1  2  3  0  0  0
  15 26  1  2  5  0]
 [ 6  0  3  4  0  0  0  3  3  0  1  0  0  0  1  0  0  0  5  3  1  9  0  0
   0  1 52  0  0  2]
 [ 1  3  0  0  0  4  3  1  2  2  0  1  2  0 12  1  8  0  2  0  2  2  3  2
   1  2  1 28  2  6]
 [ 1  1  1  4  0  4  5  7  3  0  0  1  1  3  3  2  3  1  0  1  2  4  0  5
   0  3  2  1 36  6]
 [ 0  7  2  1  2  0  3  2  2  0  2  1  4  5  3  3  6  3  2  4  1  5  2  3
   1  0  1  1  8 22]]
Accuracy: 45.1%
Class baboon : Precision 50.0%  Recall 33.3%
Class banana : Precision 34.9%  Recall 40.8%
Class bee : Precision 30.4%  Recall 46.2%
Class bison : Precision 49.1%  Recall 58.7%
Class butterfly : Precision 73.1%  Recall 76.0%
Class candle : Precision 33.0%  Recall 34.8%
Class cardigan : Precision 54.8%  Recall 43.5%
Class chihuahua : Precision 30.5%  Recall 35.2%
Class elephant : Precision 41.1%  Recall 49.4%
Class espresso : Precision 57.3%  Recall 49.4%
Class fly : Precision 51.1%  Recall 46.4%
Class goldfish : Precision 62.1%  Recall 58.7%
Class goose : Precision 42.7%  Recall 36.4%
Class grasshopper : Precision 40.7%  Recall 48.9%
Class hourglass : Precision 36.9%  Recall 45.6%
Class icecream : Precision 35.3%  Recall 12.8%
Class ipod : Precision 35.4%  Recall 47.2%
Class jellyfish : Precision 76.4%  Recall 81.9%
Class koala : Precision 57.4%  Recall 66.7%
Class ladybug : Precision 37.7%  Recall 40.8%
Class lion : Precision 48.7%  Recall 40.0%
Class mushroom : Precision 38.5%  Recall 42.7%
Class penguin : Precision 50.0%  Recall 44.3%
Class pig : Precision 30.3%  Recall 35.5%
Class pizza : Precision 60.4%  Recall 63.2%
Class pretzel : Precision 40.6%  Recall 30.2%
Class redpanda : Precision 58.4%  Recall 55.3%
Class refrigerator : Precision 50.0%  Recall 30.8%
Class sombrero : Precision 34.3%  Recall 36.0%
Class umbrella : Precision 25.3%  Recall 22.9%
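The helper functions above follow the usual convention of rows = true class, columns = predicted class. A small worked example on a toy 2-class matrix (hypothetical numbers) shows how the three metrics fall out of the same array:

```python
import numpy as np

cm = np.array([[8, 2],     # row = true class, column = predicted class
               [1, 9]])

accuracy = cm.trace() / cm.sum()           # (8 + 9) / 20 = 0.85
recall = np.diag(cm) / cm.sum(axis=1)      # per-class: [8/10, 9/10]
precision = np.diag(cm) / cm.sum(axis=0)   # per-class: [8/9,  9/11]
print(accuracy, recall, precision)
```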
In [43]:
score_array = np.array(score_list)
# make label convert to be onehot form
label_tensor = torch.tensor(label_list)
label_tensor = label_tensor.reshape((label_tensor.shape[0], 1))
label_onehot = torch.zeros(label_tensor.shape[0], num_class)
label_onehot.scatter_(dim=1, index=label_tensor, value=1)
label_onehot = np.array(label_onehot)
 
# call sklearn to calculate the corresponding fpr and tpr of each class
fpr_dict = dict()
tpr_dict = dict()
roc_auc_dict = dict()
for i in range(num_class):
    fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
    roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])
# micro
fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel())
roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])
 
# macro
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)]))
# Then interpolate all ROC curves at these points
mean_tpr = np.zeros_like(all_fpr)
for i in range(num_class):
    mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])  # np.interp replaces the deprecated scipy.interp
# Finally average it and compute AUC
mean_tpr /= num_class
fpr_dict["macro"] = all_fpr
tpr_dict["macro"] = mean_tpr
roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])
roc_auc_dict_order = sorted(((k, v) for k, v in roc_auc_dict.items() if isinstance(k, int)),
                            key=lambda x: x[1], reverse=True)  # rank classes only, excluding "micro"/"macro"

# draw the average roc curve of all classes
plt.figure()
lw = 2
plt.plot(fpr_dict["micro"], tpr_dict["micro"],
        label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["micro"]),
        color='deeppink', linestyle=':', linewidth=4)
 
plt.plot(fpr_dict["macro"], tpr_dict["macro"],
        label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["macro"]),
        color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])


for i, color in zip(range(5), colors):
    category = roc_auc_dict_order[i][0]
    plt.plot(fpr_dict[category], tpr_dict[category], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(category, roc_auc_dict_order[i][1]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Multi-class ROC curves (micro/macro averages and top-5 classes by AUC)')
plt.legend(loc="lower right")
plt.savefig('set113_roc.jpg')
plt.show()
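The `scatter_` call above builds a one-hot label matrix for the per-class ROC curves. The same thing in NumPy, for reference (illustrative only; the sklearn ROC code does not depend on it):

```python
import numpy as np

labels = np.array([0, 2, 1, 2])
num_class = 3
onehot = np.eye(num_class)[labels]   # one row per sample, a single 1 per row
print(onehot)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [0. 0. 1.]]
```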

Note: All parts below here relate to the CNN model only and not the MLP! You are advised to use your final CNN model only for each of the following parts.

2.3 Strategies for tackling overfitting¶

Using your (final) CNN model, apply the strategies below to reduce overfitting. You can reuse the network weights from previous training, an approach often referred to as fine-tuning.

  • 2.3.1 Data augmentation
  • 2.3.2 Dropout
  • 2.3.3 Hyperparameter tuning (e.g. changing learning rate)

Plot loss and accuracy graphs per epoch side by side for each implemented strategy.

2.3.1 Data augmentation¶

Implement at least five different data augmentation techniques, including both photometric and geometric augmentations.

Provide a graph and comment on what you observe.

In [11]:
data_augmentation_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                    transforms.RandomRotation((-20,20)),
                    transforms.ColorJitter(hue=0.2, saturation=0.2, brightness=0.2),
                    transforms.RandomResizedCrop(64,scale=(0.7,1.0)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                    ])

train_set = MyDataset("train_set",transform=data_augmentation_transform)
length=len(train_set)
train_size,validate_size=int(0.8*length),int(0.2*length)
train_set,validate_set=torch.utils.data.random_split(train_set,[train_size,validate_size],generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(
    train_set,
    batch_size = 64,
    shuffle = True)
validate_loader = DataLoader(
    validate_set,
    batch_size = 64,
    shuffle = True)
In [54]:
# Your code here!
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
CNN_train_loss, CNN_validate_loss, CNN_train_accuracy, CNN_validate_accuracy = [], [], [], []

for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, CNN_model, criterion, optimizer)
    CNN_train_loss.append(train_running_loss)
    CNN_train_accuracy.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model, criterion, optimizer)
    CNN_validate_loss.append(validate_running_loss)
    CNN_validate_accuracy.append(validate_running_accuracy)
    scheduler.step()
    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(CNN_model.state_dict(), './CNN_model.pt')
        
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets')
axs[0].plot(x_axis,CNN_train_loss,'r--',label='CNN_train_loss')
axs[0].plot(x_axis,CNN_validate_loss,'g--',label='CNN_validate_loss')
axs[1].plot(x_axis,CNN_train_accuracy,'b--',label='CNN_train_accuracy')
axs[1].plot(x_axis,CNN_validate_accuracy,'y--',label='CNN_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()

Before data augmentation, the model began overfitting after a dozen epochs: training accuracy approached 100% too early while validation accuracy stopped improving, indicating that the model was memorising the training data rather than learning generalisable features. Applying photometric and geometric augmentations to the images mitigated the overfitting to some extent and improved validation accuracy.

2.3.2 Dropout¶

Implement dropout in your model.

Provide a graph and comment on your choice of dropout probability.
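As background for the choice of p: a dropout layer zeroes each activation with probability p at training time and rescales the survivors by 1/(1-p) (inverted dropout), so the expected activation is unchanged and no rescaling is needed at eval time. A minimal NumPy sketch of the idea (illustrative only; `nn.Dropout` implements this internally):

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout(x, p, train=True):
    """Inverted dropout: zero units with probability p, scale the rest by 1/(1-p)."""
    if not train or p == 0.0:
        return x               # identity at eval time
    mask = rng.random(x.shape) >= p
    return x * mask / (1.0 - p)

x = np.ones(100_000)
y = dropout(x, p=0.5)
print(y.mean())  # close to 1.0: the expected activation is preserved
assert np.array_equal(dropout(x, p=0.5, train=False), x)
```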

In [56]:
# Your code here!

possibility = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]  # dropout probabilities to compare
CNN_train_loss = {str(p): [] for p in possibility}
CNN_validate_loss = {str(p): [] for p in possibility}
CNN_train_accuracy = {str(p): [] for p in possibility}
CNN_validate_accuracy = {str(p): [] for p in possibility}

for poss in possibility:

    class CNN_Class_Improved(nn.Module):
        def __init__(self):
            super(CNN_Class_Improved,self).__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
            self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
            self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
            self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flc1 = nn.Linear(64*8*8,1024)
            self.dropout = nn.Dropout(p=poss)
            self.flc2 = nn.Linear(1024,30)

        def forward(self,x):
            x = self.maxpool1(nn.functional.relu(self.conv1(x)))
            x = self.maxpool2(nn.functional.relu(self.conv2(x)))
            x = self.maxpool3(nn.functional.relu(self.conv3(x)))
            x = x.view(-1,64*8*8)
            x = self.dropout(x)
            x = nn.functional.relu(self.flc1(x))
            x = self.flc2(x)
            return x
    CNN_model_Improved =CNN_Class_Improved()

    CNN_model_Improved = CNN_model_Improved.to(device)


    nepochs = 100
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(CNN_model_Improved.parameters(), 0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
    best_loss = 1000.0

    for epoch in range(nepochs):
        train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved, criterion, optimizer)
        CNN_train_loss[str(poss)].append(train_running_loss)
        CNN_train_accuracy[str(poss)].append(train_running_accuracy)
        validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved, criterion, optimizer)
        CNN_validate_loss[str(poss)].append(validate_running_loss)
        CNN_validate_accuracy[str(poss)].append(validate_running_accuracy)
        scheduler.step()
        if validate_running_loss < best_loss:
            best_loss = validate_running_loss
            torch.save(CNN_model_Improved.state_dict(), './CNN_model_Improved.pt')

x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets with dropout possibility')
for poss in possibility:
    axs[0].plot(x_axis,CNN_train_loss[str(poss)],label='CNN_train_loss in '+str(poss))
    axs[0].plot(x_axis,CNN_validate_loss[str(poss)],label='CNN_validate_loss in '+str(poss))
    axs[1].plot(x_axis,CNN_train_accuracy[str(poss)],label='CNN_train_accuracy in '+str(poss))
    axs[1].plot(x_axis,CNN_validate_accuracy[str(poss)],label='CNN_validate_accuracy in '+str(poss))
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()

As the two plots above show, setting the dropout probability to 0.4 gives the lowest loss and highest accuracy on both the training and validation sets, so p=0.4 is the most suitable value for this CNN. Below I redefine the model class as CNN_Class_Improved with the dropout probability fixed at 0.4.
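It is worth remembering why the same model can be evaluated without removing the dropout layer: PyTorch's `nn.Dropout` uses inverted dropout, zeroing units and rescaling the survivors by 1/(1-p) in training mode, and acting as the identity in eval mode. A minimal check:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.4)
x = torch.ones(1, 10)

drop.train()       # training mode: each unit zeroed with probability 0.4,
y_train = drop(x)  # survivors scaled by 1/(1-0.4) so the expectation is unchanged

drop.eval()        # evaluation mode: dropout is the identity
y_eval = drop(x)
```

So calling `model.eval()` before validation (as the `validate` helper is assumed to do) automatically disables dropout without any change to the architecture.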

In [12]:
class CNN_Class_Improved(nn.Module):
    def __init__(self):
        super(CNN_Class_Improved,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flc1 = nn.Linear(64*8*8,1024)
        self.dropout = nn.Dropout(p=0.4)
        self.flc2 = nn.Linear(1024,30)

    def forward(self,x):
        x = self.maxpool1(nn.functional.relu(self.conv1(x)))
        x = self.maxpool2(nn.functional.relu(self.conv2(x)))
        x = self.maxpool3(nn.functional.relu(self.conv3(x)))
        x = x.view(-1,64*8*8)
        x = self.dropout(x)
        x = nn.functional.relu(self.flc1(x))
        x = self.flc2(x)
        return x
CNN_model_Improved = CNN_Class_Improved()

CNN_model_Improved = CNN_model_Improved.to(device)
In [60]:
CNN_train_loss, CNN_validate_loss, CNN_train_accuracy, CNN_validate_accuracy = [], [], [], []
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model_Improved.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0

for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved, criterion, optimizer)
    CNN_train_loss.append(train_running_loss)
    CNN_train_accuracy.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved, criterion, optimizer)
    CNN_validate_loss.append(validate_running_loss)
    CNN_validate_accuracy.append(validate_running_accuracy)
    scheduler.step()
    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(CNN_model_Improved.state_dict(), './CNN_model_Improved.pt')

x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('Loss and accuracy of the training and validation sets with dropout probability 0.4')
axs[0].plot(x_axis,CNN_train_loss,label='CNN_train_loss')
axs[0].plot(x_axis,CNN_validate_loss,label='CNN_validate_loss')
axs[1].plot(x_axis,CNN_train_accuracy,label='CNN_train_accuracy')
axs[1].plot(x_axis,CNN_validate_accuracy,label='CNN_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('accuracy (%)')
axs[0].legend()
axs[1].legend()
plt.show()

2.3.3 Hyperparameter tuning¶

Use learning rates [0.1, 0.001, 0.0001]

Provide separate graphs for loss and accuracy, each showing performance at three different learning rates
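The three learning-rate runs below repeat essentially the same training loop; a helper along these lines (a sketch, assuming the `train` and `validate` functions defined earlier in the notebook) would avoid the duplication:

```python
import torch

def run_for_lr(model, lr, nepochs, train_fn, validate_fn,
               train_loader, validate_loader, criterion, ckpt_path):
    """Train `model` for `nepochs` with Adam at learning rate `lr`,
    checkpointing whenever the validation loss improves."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    best_loss = float("inf")
    for epoch in range(nepochs):
        tl, ta = train_fn(train_loader, model, criterion, optimizer)
        vl, va = validate_fn(validate_loader, model, criterion, optimizer)
        for key, value in zip(history, (tl, ta, vl, va)):
            history[key].append(value)
        if vl < best_loss:
            best_loss = vl
            torch.save(model.state_dict(), ckpt_path)
    return history
```

Each learning rate would then get a fresh model and an independent `best_loss`, e.g. `run_for_lr(CNN_Class_Improved().to(device), 0.001, 100, train, validate, train_loader, validate_loader, nn.CrossEntropyLoss(), './CNN_model_Improved_0001.pt')`.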

In [32]:
nepochs = 100
criterion = nn.CrossEntropyLoss()
best_loss = 1000.0
train_loss_lr_01 = []
train_acc_lr_01 = []
validate_loss_lr_01 = []
validate_acc_lr_01 = []

# Start from a freshly initialised model (like the other two runs) so the
# three learning rates are compared fairly, and checkpoint to a separate file
# so the earlier dropout run's best weights are not overwritten.
CNN_model_Improved_01 = CNN_Class_Improved()
CNN_model_Improved_01 = CNN_model_Improved_01.to(device)
optimizer = torch.optim.Adam(CNN_model_Improved_01.parameters(), lr=0.1)

for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved_01, criterion, optimizer)
    train_loss_lr_01.append(train_running_loss)
    train_acc_lr_01.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved_01, criterion, optimizer)
    validate_loss_lr_01.append(validate_running_loss)
    validate_acc_lr_01.append(validate_running_accuracy)

    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(CNN_model_Improved_01.state_dict(), './CNN_model_Improved_01.pt')
    print('epoch:{},train_loss:{}, train_acc:{}, val_loss:{}, val_acc:{}'.format(epoch+1,train_running_loss,train_running_accuracy,validate_running_loss,validate_running_accuracy))
epoch:1,train_loss:2164.766256813467, train_acc:3.2513561248779297, val_loss:3.4151479754337046, val_acc:3.3793604373931885
epoch:2,train_loss:3.417014720171866, train_acc:3.155818462371826, val_loss:3.4172210637913194, val_acc:3.1371123790740967
epoch:3,train_loss:3.4147649728334866, train_acc:3.260601758956909, val_loss:3.4149644984755407, val_acc:3.5004844665527344
epoch:4,train_loss:3.4186991643623488, train_acc:3.026380777359009, val_loss:3.4159545177637143, val_acc:3.4156975746154785
epoch:5,train_loss:3.4175619698135105, train_acc:3.2575197219848633, val_loss:3.415543012840803, val_acc:3.2703487873077393
epoch:6,train_loss:3.417246429172493, train_acc:3.125, val_loss:3.4096909234690114, val_acc:3.16133713722229
epoch:7,train_loss:3.418316842536249, train_acc:3.0171351432800293, val_loss:3.417351800341939, val_acc:3.5368216037750244
epoch:8,train_loss:3.418994481746967, train_acc:3.260601758956909, val_loss:3.4151616262835125, val_acc:3.633720874786377
epoch:9,train_loss:3.420121513174836, train_acc:3.0448715686798096, val_loss:3.414994395056436, val_acc:3.4883720874786377
epoch:10,train_loss:3.4162561258620765, train_acc:3.402366876602173, val_loss:3.4193347165750905, val_acc:2.579941749572754
epoch:11,train_loss:3.419583608412884, train_acc:2.8907790184020996, val_loss:3.426316582879355, val_acc:3.3430233001708984
epoch:12,train_loss:3.4176741506926405, train_acc:3.590359926223755, val_loss:3.4160396331964535, val_acc:3.633720874786377
epoch:13,train_loss:3.4167988328538703, train_acc:3.3099112510681152, val_loss:3.4130693701810615, val_acc:3.4156975746154785
epoch:14,train_loss:3.419260002452241, train_acc:2.98939847946167, val_loss:3.4215198117633197, val_acc:3.16133713722229
epoch:15,train_loss:3.417573000552386, train_acc:3.1342456340789795, val_loss:3.4205423454905666, val_acc:3.3066859245300293
epoch:16,train_loss:3.418926923232671, train_acc:3.0602810382843018, val_loss:3.4173313074333724, val_acc:3.4641470909118652
epoch:17,train_loss:3.416952684786193, train_acc:3.174309492111206, val_loss:3.4297832500102907, val_acc:3.1734495162963867
epoch:18,train_loss:3.419816007275553, train_acc:3.4054486751556396, val_loss:3.416986337927885, val_acc:3.427809953689575
epoch:19,train_loss:3.4161838441205448, train_acc:3.2051284313201904, val_loss:3.413021359332772, val_acc:3.3430233001708984
epoch:20,train_loss:3.4196937112413215, train_acc:3.248274087905884, val_loss:3.4172293363615522, val_acc:3.4520349502563477
epoch:21,train_loss:3.417763116091666, train_acc:3.568787097930908, val_loss:3.414470162502555, val_acc:3.4156975746154785
epoch:22,train_loss:3.4174171650903467, train_acc:3.2729289531707764, val_loss:3.4207970042561375, val_acc:3.16133713722229
epoch:23,train_loss:3.4184304313546807, train_acc:3.282174587249756, val_loss:3.414499371550804, val_acc:3.6579458713531494
epoch:24,train_loss:3.41857082321799, train_acc:3.1342456340789795, val_loss:3.4210255312365154, val_acc:3.234011650085449
epoch:25,train_loss:3.416219210483619, train_acc:3.4609220027923584, val_loss:3.4134753526643267, val_acc:3.4520349502563477
epoch:26,train_loss:3.4197405767158644, train_acc:3.088017702102661, val_loss:3.4153738520866215, val_acc:3.5731587409973145
epoch:27,train_loss:3.420783406884007, train_acc:3.4270217418670654, val_loss:3.4164660808651948, val_acc:2.579941749572754
epoch:28,train_loss:3.4167481569143443, train_acc:3.2051284313201904, val_loss:3.4207183094911797, val_acc:3.1734495162963867
epoch:29,train_loss:3.4205114516986193, train_acc:2.847633123397827, val_loss:3.418117695076521, val_acc:3.4156975746154785
epoch:30,train_loss:3.4190025611742008, train_acc:2.933925151824951, val_loss:3.418670748555383, val_acc:3.4520349502563477
epoch:31,train_loss:3.4182936806650557, train_acc:3.072608470916748, val_loss:3.4141087698382, val_acc:2.616279125213623
epoch:32,train_loss:3.4167870247857812, train_acc:3.52255916595459, val_loss:3.410075542538665, val_acc:2.773740291595459
epoch:33,train_loss:3.4175721456313273, train_acc:3.158900499343872, val_loss:3.4164779851602955, val_acc:3.427809953689575
epoch:34,train_loss:3.4188687293487185, train_acc:3.2174556255340576, val_loss:3.4167783703914907, val_acc:3.7306201457977295
epoch:35,train_loss:3.4183290385635647, train_acc:3.026380777359009, val_loss:3.4241222503573394, val_acc:3.125
epoch:36,train_loss:3.416948575239915, train_acc:3.0602810382843018, val_loss:3.4283005470453305, val_acc:3.125
epoch:37,train_loss:3.422103128489658, train_acc:3.161982297897339, val_loss:3.4196407850398574, val_acc:3.4641470909118652
epoch:38,train_loss:3.420339352985811, train_acc:3.0787723064422607, val_loss:3.411231368087059, val_acc:2.616279125213623
epoch:39,train_loss:3.4194834387514015, train_acc:3.1527366638183594, val_loss:3.425304961758991, val_acc:3.3430233001708984
epoch:40,train_loss:3.4183861207679884, train_acc:3.075690507888794, val_loss:3.4169876187346704, val_acc:3.3066859245300293
epoch:41,train_loss:3.41841212532224, train_acc:3.0109713077545166, val_loss:3.4204413391822994, val_acc:3.4641470909118652
epoch:42,train_loss:3.419490743670943, train_acc:3.0664448738098145, val_loss:3.4204198648763255, val_acc:3.3430233001708984
epoch:43,train_loss:3.417487470355965, train_acc:3.2359466552734375, val_loss:3.419445808543715, val_acc:3.3430233001708984
epoch:44,train_loss:3.4181505279428155, train_acc:3.3592207431793213, val_loss:3.422358607136926, val_acc:3.4520349502563477
epoch:45,train_loss:3.4212284158672808, train_acc:3.3099112510681152, val_loss:3.4102145461148994, val_acc:3.16133713722229
epoch:46,train_loss:3.4180980814984565, train_acc:3.023298740386963, val_loss:3.4169417314751205, val_acc:2.9796512126922607
epoch:47,train_loss:3.4190124929303956, train_acc:3.2051284313201904, val_loss:3.4161524384520776, val_acc:3.5368216037750244
epoch:48,train_loss:3.415075201960005, train_acc:3.23286509513855, val_loss:3.423741251923317, val_acc:3.633720874786377
epoch:49,train_loss:3.4189514478988197, train_acc:3.2883384227752686, val_loss:3.4172512154246486, val_acc:3.015988349914551
epoch:50,train_loss:3.417860299172486, train_acc:3.6088509559631348, val_loss:3.4190574967583944, val_acc:3.1734495162963867
epoch:51,train_loss:3.416621050185706, train_acc:3.137327194213867, val_loss:3.4166778298311455, val_acc:3.391472816467285
epoch:52,train_loss:3.418241732219267, train_acc:3.534886360168457, val_loss:3.417380976122479, val_acc:2.579941749572754
epoch:53,train_loss:3.4201752236608924, train_acc:3.161982297897339, val_loss:3.4176531337028324, val_acc:2.616279125213623
epoch:54,train_loss:3.417625130986321, train_acc:3.5965237617492676, val_loss:3.4332177361776663, val_acc:3.3066859245300293
epoch:55,train_loss:3.418752680163412, train_acc:3.094181537628174, val_loss:3.4256202620129255, val_acc:3.3793604373931885
epoch:56,train_loss:3.4195547611755734, train_acc:3.6612424850463867, val_loss:3.42174948093503, val_acc:3.4156975746154785
epoch:57,train_loss:3.4177634631388285, train_acc:3.23286509513855, val_loss:3.4214947944463687, val_acc:2.616279125213623
epoch:58,train_loss:3.421686240201871, train_acc:2.9770710468292236, val_loss:3.411846914956736, val_acc:3.4520349502563477
epoch:59,train_loss:3.4174652720344136, train_acc:3.5410501956939697, val_loss:3.4218920275222424, val_acc:2.8706395626068115
epoch:60,train_loss:3.420347655313255, train_acc:3.245192289352417, val_loss:3.4128389635751413, val_acc:2.9796512126922607
epoch:61,train_loss:3.4182452534782817, train_acc:2.9400887489318848, val_loss:3.422146891438684, val_acc:3.3430233001708984
epoch:62,train_loss:3.415685522485767, train_acc:3.2513561248779297, val_loss:3.4185479352640553, val_acc:3.4156975746154785
epoch:63,train_loss:3.417576566955747, train_acc:3.094181537628174, val_loss:3.4088850853055024, val_acc:3.234011650085449
epoch:64,train_loss:3.4182081278964613, train_acc:2.866124153137207, val_loss:3.416826420052107, val_acc:3.318798303604126
epoch:65,train_loss:3.418849708060541, train_acc:3.2051284313201904, val_loss:3.408453292624895, val_acc:3.3430233001708984
epoch:66,train_loss:3.4193298774358083, train_acc:3.226701259613037, val_loss:3.4256949535636014, val_acc:2.9796512126922607
epoch:67,train_loss:3.417695463056395, train_acc:3.3099112510681152, val_loss:3.415983948596688, val_acc:3.4641470909118652
epoch:68,train_loss:3.416713549540593, train_acc:3.260601758956909, val_loss:3.413653795109239, val_acc:3.318798303604126
epoch:69,train_loss:3.4188056514107967, train_acc:3.1650640964508057, val_loss:3.42025640398957, val_acc:3.439922571182251
epoch:70,train_loss:3.4203963942781708, train_acc:3.2544379234313965, val_loss:3.4184001656465752, val_acc:3.125
epoch:71,train_loss:3.4179979724996894, train_acc:3.174309492111206, val_loss:3.4074795856032263, val_acc:3.015988349914551
epoch:72,train_loss:3.417654975631533, train_acc:3.155818462371826, val_loss:3.417345950769824, val_acc:3.5004844665527344
epoch:73,train_loss:3.4162017861766927, train_acc:3.1681461334228516, val_loss:3.420057013977406, val_acc:3.3430233001708984
epoch:74,train_loss:3.419482875857833, train_acc:3.075690507888794, val_loss:3.4232967398887455, val_acc:3.3430233001708984
epoch:75,train_loss:3.4189944196734907, train_acc:3.229782819747925, val_loss:3.4140212092288706, val_acc:3.1371123790740967
epoch:76,train_loss:3.4205527813476926, train_acc:3.1465728282928467, val_loss:3.427143795545711, val_acc:2.737403154373169
epoch:77,train_loss:3.4187424592012485, train_acc:3.276010751724243, val_loss:3.4177852231402728, val_acc:2.579941749572754
epoch:78,train_loss:3.419346115292882, train_acc:3.3592207431793213, val_loss:3.4079314442568047, val_acc:3.4156975746154785
epoch:79,train_loss:3.419485705844044, train_acc:3.1804733276367188, val_loss:3.4205848228099733, val_acc:3.3430233001708984
epoch:80,train_loss:3.4190032087134186, train_acc:3.2852563858032227, val_loss:3.4090217435082724, val_acc:3.427809953689575
epoch:81,train_loss:3.418792549675033, train_acc:3.1804733276367188, val_loss:3.411422552064408, val_acc:3.3793604373931885
epoch:82,train_loss:3.4182332690650896, train_acc:3.00788950920105, val_loss:3.4216125122336454, val_acc:3.4156975746154785
epoch:83,train_loss:3.417820659614879, train_acc:3.2575197219848633, val_loss:3.411681385927422, val_acc:3.948643445968628
epoch:84,train_loss:3.4190039352552426, train_acc:3.0633628368377686, val_loss:3.412457349688508, val_acc:2.8706395626068115
epoch:85,train_loss:3.418832723910992, train_acc:2.9462523460388184, val_loss:3.414388662160829, val_acc:2.579941749572754
epoch:86,train_loss:3.4185697750226987, train_acc:3.4085307121276855, val_loss:3.416147836419039, val_acc:3.1371123790740967
epoch:87,train_loss:3.4195547343711175, train_acc:3.3438117504119873, val_loss:3.413477354271467, val_acc:3.16133713722229
epoch:88,train_loss:3.4165763332998966, train_acc:3.4763314723968506, val_loss:3.4125312040018483, val_acc:3.6458332538604736
epoch:89,train_loss:3.416477335980658, train_acc:3.294501781463623, val_loss:3.4151419373445733, val_acc:3.5004844665527344
epoch:90,train_loss:3.416105400175738, train_acc:3.41469407081604, val_loss:3.418804551279822, val_acc:3.4156975746154785
epoch:91,train_loss:3.419120620693681, train_acc:3.3191568851470947, val_loss:3.4148293206858082, val_acc:3.015988349914551
epoch:92,train_loss:3.417863178535326, train_acc:3.433185338973999, val_loss:3.4063180546427883, val_acc:3.5247092247009277
epoch:93,train_loss:3.4200609520342224, train_acc:3.3253207206726074, val_loss:3.413665455441142, val_acc:3.015988349914551
epoch:94,train_loss:3.4180003157734165, train_acc:3.1003451347351074, val_loss:3.416378331738849, val_acc:3.779069662094116
epoch:95,train_loss:3.421211277944802, train_acc:3.174309492111206, val_loss:3.416663186494694, val_acc:3.5368216037750244
epoch:96,train_loss:3.4179833626606055, train_acc:3.109590530395508, val_loss:3.41137527310571, val_acc:3.561046600341797
epoch:97,train_loss:3.4203324515438642, train_acc:3.23286509513855, val_loss:3.412340297255405, val_acc:3.3430233001708984
epoch:98,train_loss:3.419171688824716, train_acc:3.1989645957946777, val_loss:3.4238855949667997, val_acc:2.9796512126922607
epoch:99,train_loss:3.4177519950640978, train_acc:3.276010751724243, val_loss:3.417204640632452, val_acc:3.4641470909118652
epoch:100,train_loss:3.4182146978096144, train_acc:3.2667651176452637, val_loss:3.4184995030247887, val_acc:3.3066859245300293
In [33]:
CNN_model_Improved_0001 =CNN_Class_Improved()

CNN_model_Improved_0001 = CNN_model_Improved_0001.to(device)
train_loss_lr_0001 = []
train_acc_lr_0001 = []
validate_loss_lr_0001 = []
validate_acc_lr_0001 = []
best_loss = 1000.0  # reset so checkpointing is independent of the previous run
optimizer = torch.optim.Adam(CNN_model_Improved_0001.parameters(), lr=0.001)


for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved_0001, criterion, optimizer)
    train_loss_lr_0001.append(train_running_loss)
    train_acc_lr_0001.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved_0001, criterion, optimizer)
    validate_loss_lr_0001.append(validate_running_loss)
    validate_acc_lr_0001.append(validate_running_accuracy)

    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(CNN_model_Improved_0001.state_dict(), './CNN_model_Improved_0001.pt')
    print('epoch:{},train_loss:{}, train_acc:{}, val_loss:{}, val_acc:{}'.format(epoch+1,train_running_loss,train_running_accuracy,validate_running_loss,validate_running_accuracy))
epoch:1,train_loss:3.1911271859908243, train_acc:9.544502258300781, val_loss:2.986053084218225, val_acc:15.23740291595459
epoch:2,train_loss:2.916512415959285, train_acc:16.854660034179688, val_loss:2.801635553670484, val_acc:19.28294563293457
epoch:3,train_loss:2.736940922821767, train_acc:21.859590530395508, val_loss:2.682002843812455, val_acc:23.897769927978516
epoch:4,train_loss:2.620152408554709, train_acc:24.63326072692871, val_loss:2.6289401220720867, val_acc:26.392927169799805
epoch:5,train_loss:2.5206565207983616, train_acc:27.11107063293457, val_loss:2.5242844792299493, val_acc:26.320253372192383
epoch:6,train_loss:2.450798304123286, train_acc:29.262203216552734, val_loss:2.4371402125025905, val_acc:30.365793228149414
epoch:7,train_loss:2.4043825702554376, train_acc:30.544254302978516, val_loss:2.3683269938757254, val_acc:31.855619430541992
epoch:8,train_loss:2.3393820460731463, train_acc:31.992727279663086, val_loss:2.358646592428518, val_acc:32.65503692626953
epoch:9,train_loss:2.3059581927293857, train_acc:32.68922424316406, val_loss:2.31296735586122, val_acc:34.338661193847656
epoch:10,train_loss:2.2608506270414273, train_acc:34.310279846191406, val_loss:2.29988732448844, val_acc:34.14486312866211
epoch:11,train_loss:2.2049239909155127, train_acc:35.09307098388672, val_loss:2.272709147874699, val_acc:34.39922332763672
epoch:12,train_loss:2.1789202817092987, train_acc:36.04844665527344, val_loss:2.2244559609612753, val_acc:35.804264068603516
epoch:13,train_loss:2.1175631799641446, train_acc:38.06089782714844, val_loss:2.192433759223583, val_acc:35.84060287475586
epoch:14,train_loss:2.0764113967940654, train_acc:39.78057098388672, val_loss:2.1785506315009537, val_acc:36.19186019897461
epoch:15,train_loss:2.058042284299636, train_acc:38.59097671508789, val_loss:2.1554948840030406, val_acc:38.50532913208008
epoch:16,train_loss:2.046162011355338, train_acc:39.715850830078125, val_loss:2.1306101277817127, val_acc:37.89970779418945
epoch:17,train_loss:2.001729795918662, train_acc:40.51405334472656, val_loss:2.1323553185130275, val_acc:37.89970779418945
epoch:18,train_loss:1.9814729225000687, train_acc:41.58654022216797, val_loss:2.117563998976419, val_acc:38.26308059692383
epoch:19,train_loss:1.9500723964363866, train_acc:42.42788314819336, val_loss:2.10549044054608, val_acc:39.21996307373047
epoch:20,train_loss:1.9504304912668713, train_acc:42.594303131103516, val_loss:2.1764658024144725, val_acc:37.97238540649414
epoch:21,train_loss:1.9215907302833872, train_acc:43.044254302978516, val_loss:2.087116565815238, val_acc:40.164730072021484
epoch:22,train_loss:1.907182706883673, train_acc:43.7376708984375, val_loss:2.1189985164376193, val_acc:39.35319900512695
epoch:23,train_loss:1.8836283514485557, train_acc:43.41407775878906, val_loss:2.0775403366532434, val_acc:41.085269927978516
epoch:24,train_loss:1.8656544762955616, train_acc:44.12598419189453, val_loss:2.0783199127330336, val_acc:39.80135726928711
epoch:25,train_loss:1.8698474888265486, train_acc:44.2862434387207, val_loss:2.1137201092963993, val_acc:39.24418640136719
epoch:26,train_loss:1.8349722581502248, train_acc:45.55288314819336, val_loss:2.1221803953481273, val_acc:40.261627197265625
epoch:27,train_loss:1.8137300980867013, train_acc:45.294010162353516, val_loss:2.0796976006308268, val_acc:40.007266998291016
epoch:28,train_loss:1.8014230580019528, train_acc:46.39114761352539, val_loss:2.030866586884787, val_acc:43.24127960205078
epoch:29,train_loss:1.785078847902061, train_acc:46.9612922668457, val_loss:2.035461370335069, val_acc:41.0610466003418
epoch:30,train_loss:1.7569422474979648, train_acc:47.583824157714844, val_loss:2.068865886954374, val_acc:41.569766998291016
epoch:31,train_loss:1.737341010358912, train_acc:48.243343353271484, val_loss:2.073564939720686, val_acc:40.45542526245117
epoch:32,train_loss:1.7261173259577103, train_acc:47.65779113769531, val_loss:2.045435675354891, val_acc:40.285850524902344
epoch:33,train_loss:1.7337800271412325, train_acc:47.53143310546875, val_loss:2.072408454362736, val_acc:40.34641647338867
epoch:34,train_loss:1.6981979968279777, train_acc:48.841224670410156, val_loss:2.0786357868549437, val_acc:41.315406799316406
epoch:35,train_loss:1.6865440846900264, train_acc:48.918270111083984, val_loss:2.0532170035118282, val_acc:40.92781066894531
epoch:36,train_loss:1.6882125315581553, train_acc:48.96141815185547, val_loss:2.0575836924619453, val_acc:41.35174560546875
epoch:37,train_loss:1.6592367619452393, train_acc:49.6240119934082, val_loss:2.055528163909912, val_acc:41.18217086791992
epoch:38,train_loss:1.671578735289489, train_acc:49.725711822509766, val_loss:2.0572038029515465, val_acc:41.812015533447266
epoch:39,train_loss:1.6578207319304787, train_acc:50.015411376953125, val_loss:2.097941980805508, val_acc:40.44331359863281
epoch:40,train_loss:1.6344125976223918, train_acc:50.66876220703125, val_loss:2.0761267224023507, val_acc:40.60077667236328
epoch:41,train_loss:1.6378633284709863, train_acc:50.30818176269531, val_loss:2.0407898731009904, val_acc:41.63032913208008
epoch:42,train_loss:1.6259075494912953, train_acc:50.397560119628906, val_loss:2.083180918249973, val_acc:42.405521392822266
epoch:43,train_loss:1.6108569266528068, train_acc:50.56089782714844, val_loss:2.0441354041875797, val_acc:42.708335876464844
epoch:44,train_loss:1.5851273409713655, train_acc:51.35909652709961, val_loss:2.0792733985324237, val_acc:41.42441940307617
epoch:45,train_loss:1.5764018311303043, train_acc:51.550174713134766, val_loss:2.0797172701636026, val_acc:41.89680099487305
epoch:46,train_loss:1.5714580349668243, train_acc:51.9138298034668, val_loss:2.092423987943073, val_acc:41.87257766723633
epoch:47,train_loss:1.556293245603347, train_acc:52.77983474731445, val_loss:2.0449386807375176, val_acc:42.51453399658203
epoch:48,train_loss:1.5329540262560872, train_acc:53.374629974365234, val_loss:2.0446904620458914, val_acc:42.21172332763672
epoch:49,train_loss:1.5542095405815621, train_acc:52.42233657836914, val_loss:2.0868139377860135, val_acc:41.44864273071289
epoch:50,train_loss:1.517126593364061, train_acc:53.818416595458984, val_loss:2.0363581790480505, val_acc:43.75
epoch:51,train_loss:1.5025355089345627, train_acc:54.385475158691406, val_loss:2.1164575343908267, val_acc:41.96947479248047
epoch:52,train_loss:1.528262742877712, train_acc:53.25443649291992, val_loss:2.0643693070079006, val_acc:42.41763687133789
epoch:53,train_loss:1.5132660223887517, train_acc:54.07421112060547, val_loss:2.0913036978522013, val_acc:42.151161193847656
epoch:54,train_loss:1.4974592740719135, train_acc:54.203651428222656, val_loss:2.0812370583068494, val_acc:42.562984466552734
epoch:55,train_loss:1.4779720920077442, train_acc:54.92171859741211, val_loss:2.1356683553651323, val_acc:41.99370193481445
epoch:56,train_loss:1.4970618620426697, train_acc:54.024898529052734, val_loss:2.1307524609011272, val_acc:40.964149475097656
epoch:57,train_loss:1.4814021880104697, train_acc:55.26072311401367, val_loss:2.053830152334169, val_acc:41.63032913208008
epoch:58,train_loss:1.4547195914228992, train_acc:55.48878479003906, val_loss:2.0846458839815716, val_acc:43.374515533447266
epoch:59,train_loss:1.45824292072883, train_acc:55.69526672363281, val_loss:2.080652666646381, val_acc:42.11482620239258
epoch:60,train_loss:1.4478308946423277, train_acc:55.88634490966797, val_loss:2.070925468622252, val_acc:42.65988540649414
epoch:61,train_loss:1.4465424679440153, train_acc:55.54733657836914, val_loss:2.0982873384342637, val_acc:40.625
epoch:62,train_loss:1.417205399310095, train_acc:56.055843353271484, val_loss:2.0567419307176458, val_acc:42.11482620239258
epoch:63,train_loss:1.4109138677106101, train_acc:56.74618148803711, val_loss:2.0790942436040836, val_acc:41.19428253173828
epoch:64,train_loss:1.4270178206573576, train_acc:56.21609878540039, val_loss:2.0332969399385674, val_acc:44.2344970703125
epoch:65,train_loss:1.4109003572068977, train_acc:57.066688537597656, val_loss:2.13479960519214, val_acc:39.87403106689453
epoch:66,train_loss:1.4048020120203142, train_acc:57.44267654418945, val_loss:2.026360500690549, val_acc:44.40406799316406
epoch:67,train_loss:1.4019899685707318, train_acc:57.31940460205078, val_loss:2.040835521941961, val_acc:41.7514533996582
epoch:68,train_loss:1.4054638676389435, train_acc:56.81398010253906, val_loss:2.1204093888748523, val_acc:42.38129806518555
epoch:69,train_loss:1.3860898419950136, train_acc:57.44267654418945, val_loss:2.092866908672244, val_acc:43.544090270996094
epoch:70,train_loss:1.3722516675672587, train_acc:57.73545455932617, val_loss:2.1119466715080795, val_acc:42.97480392456055
epoch:71,train_loss:1.390832188566761, train_acc:57.75394821166992, val_loss:2.059696294540583, val_acc:43.02325439453125
epoch:72,train_loss:1.3767634534976891, train_acc:57.73545455932617, val_loss:2.069774339365405, val_acc:43.47141647338867
epoch:73,train_loss:1.3583349377445921, train_acc:58.14841842651367, val_loss:2.1562981633252876, val_acc:43.386627197265625
epoch:74,train_loss:1.3378441129210432, train_acc:58.44119644165039, val_loss:2.08634912690451, val_acc:44.367733001708984
epoch:75,train_loss:1.3603977222414412, train_acc:57.89878845214844, val_loss:2.085165400837743, val_acc:42.84156799316406
epoch:76,train_loss:1.332293771427764, train_acc:58.49050521850586, val_loss:2.1116821516391844, val_acc:42.35707092285156
epoch:77,train_loss:1.3401080881350138, train_acc:58.76787185668945, val_loss:2.1042044329088787, val_acc:42.79311752319336
epoch:78,train_loss:1.3381292012316235, train_acc:58.79253005981445, val_loss:2.0878591482029405, val_acc:44.682655334472656
epoch:79,train_loss:1.3298471570014954, train_acc:59.68010330200195, val_loss:2.0947961751804796, val_acc:42.92635726928711
epoch:80,train_loss:1.3317312982660778, train_acc:59.16851806640625, val_loss:2.1222938964533253, val_acc:42.50242233276367
epoch:81,train_loss:1.3220863324650647, train_acc:59.66161346435547, val_loss:2.072736096936603, val_acc:42.50242233276367
epoch:82,train_loss:1.3219773127482488, train_acc:59.32569122314453, val_loss:2.14878735431405, val_acc:41.48497772216797
epoch:83,train_loss:1.306615230247114, train_acc:59.547584533691406, val_loss:2.13507309902546, val_acc:41.90891647338867
epoch:84,train_loss:1.3059152470537896, train_acc:59.29795455932617, val_loss:2.1891779594643173, val_acc:42.53875732421875
epoch:85,train_loss:1.2921045835201557, train_acc:59.76023483276367, val_loss:2.1152667971544488, val_acc:42.042152404785156
epoch:86,train_loss:1.2735567209283276, train_acc:60.18552780151367, val_loss:2.128107009932052, val_acc:43.2655029296875
epoch:87,train_loss:1.2963687403667608, train_acc:60.3611946105957, val_loss:2.165547340415245, val_acc:42.28439712524414
epoch:88,train_loss:1.2947157050025533, train_acc:60.23483657836914, val_loss:2.109045857606932, val_acc:43.713661193847656
epoch:89,train_loss:1.2850948854310977, train_acc:60.34270477294922, val_loss:2.1421401417532633, val_acc:43.483524322509766
epoch:90,train_loss:1.2887458261653517, train_acc:60.17628479003906, val_loss:2.1347553785457167, val_acc:41.95736312866211
epoch:91,train_loss:1.2774624993815225, train_acc:60.15470886230469, val_loss:2.111482395682224, val_acc:41.71511459350586
epoch:92,train_loss:1.258455208420048, train_acc:61.09159469604492, val_loss:2.1545954942703247, val_acc:43.21705627441406
epoch:93,train_loss:1.2406606804689713, train_acc:61.153228759765625, val_loss:2.1607651239217716, val_acc:43.21705627441406
epoch:94,train_loss:1.2692347514558826, train_acc:60.57384490966797, val_loss:2.127889716347983, val_acc:43.21705627441406
epoch:95,train_loss:1.2696357436434051, train_acc:60.5861701965332, val_loss:2.143072849096254, val_acc:42.902130126953125
epoch:96,train_loss:1.2370165519460419, train_acc:61.90520477294922, val_loss:2.181827423184417, val_acc:42.10271072387695
epoch:97,train_loss:1.25053066791162, train_acc:61.60626220703125, val_loss:2.176362212314162, val_acc:43.77422332763672
epoch:98,train_loss:1.2255728777343704, train_acc:62.244205474853516, val_loss:2.1287347937739174, val_acc:41.65455627441406
epoch:99,train_loss:1.223124155278742, train_acc:62.13325881958008, val_loss:2.218731730483299, val_acc:41.60610580444336
epoch:100,train_loss:1.2221639054061393, train_acc:61.809661865234375, val_loss:2.144239292588345, val_acc:42.47819900512695
In [28]:
CNN_model_Improved_00001 = CNN_Class_Improved()

CNN_model_Improved_00001 = CNN_model_Improved_00001.to(device)
train_loss_lr_00001 = []
train_acc_lr_00001 = []
validate_loss_lr_00001 = []
validate_acc_lr_00001 = []
optimizer = torch.optim.Adam(CNN_model_Improved_00001.parameters(), lr=0.0001)
best_loss = math.inf  # reset the checkpoint threshold so it is not carried over from earlier runs

nepochs = 100
for epoch in range(nepochs):
    train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved_00001, criterion, optimizer)
    train_loss_lr_00001.append(train_running_loss)
    train_acc_lr_00001.append(train_running_accuracy)
    validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved_00001, criterion, optimizer)
    validate_loss_lr_00001.append(validate_running_loss)
    validate_acc_lr_00001.append(validate_running_accuracy)

    if validate_running_loss < best_loss:
        best_loss = validate_running_loss
        torch.save(CNN_model_Improved_00001.state_dict(), './CNN_model_Improved_00001.pt')
    print('epoch:{},train_loss:{}, train_acc:{}, val_loss:{}, val_acc:{}'.format(epoch+1,train_running_loss,train_running_accuracy,validate_running_loss,validate_running_accuracy))
epoch:1,train_loss:3.2796755302587206, train_acc:8.444280624389648, val_loss:3.1328728864359303, val_acc:12.051841735839844
epoch:2,train_loss:3.066263253872211, train_acc:13.929981231689453, val_loss:3.001076387804608, val_acc:15.964146614074707
epoch:3,train_loss:2.941012852290678, train_acc:17.239891052246094, val_loss:2.876550230869027, val_acc:18.265504837036133
epoch:4,train_loss:2.851026439102444, train_acc:19.125986099243164, val_loss:2.810957442882449, val_acc:20.32461166381836
epoch:5,train_loss:2.787592075280184, train_acc:20.972015380859375, val_loss:2.7556045831635942, val_acc:21.790212631225586
epoch:6,train_loss:2.7347744526947744, train_acc:22.343442916870117, val_loss:2.7041152854298436, val_acc:22.953004837036133
epoch:7,train_loss:2.6939765955569475, train_acc:23.6624755859375, val_loss:2.692073761030685, val_acc:22.347383499145508
epoch:8,train_loss:2.6639692994969835, train_acc:24.33740234375, val_loss:2.6464970333631648, val_acc:25.557170867919922
epoch:9,train_loss:2.634774837268175, train_acc:24.919872283935547, val_loss:2.6104013919830322, val_acc:26.392927169799805
epoch:10,train_loss:2.593165810996964, train_acc:26.010848999023438, val_loss:2.6000261362208876, val_acc:25.968990325927734
epoch:11,train_loss:2.569996304765961, train_acc:26.919994354248047, val_loss:2.575469815453818, val_acc:25.121124267578125
epoch:12,train_loss:2.53580813153961, train_acc:27.671966552734375, val_loss:2.5255050603733507, val_acc:28.37936019897461
epoch:13,train_loss:2.5142372176492, train_acc:27.85687828063965, val_loss:2.5231597811676734, val_acc:28.10077667236328
epoch:14,train_loss:2.484837282338791, train_acc:28.762943267822266, val_loss:2.512596291165019, val_acc:28.972869873046875
epoch:15,train_loss:2.4733378957714556, train_acc:29.570390701293945, val_loss:2.4739946099214776, val_acc:30.1719970703125
epoch:16,train_loss:2.457911544297574, train_acc:29.219058990478516, val_loss:2.4756919561430464, val_acc:29.832849502563477
epoch:17,train_loss:2.4347294315078556, train_acc:29.893985748291016, val_loss:2.4439604947733327, val_acc:30.3900203704834
epoch:18,train_loss:2.412981839575006, train_acc:31.135971069335938, val_loss:2.4274158671844837, val_acc:32.46124267578125
epoch:19,train_loss:2.3984049692661804, train_acc:31.274654388427734, val_loss:2.413873550503753, val_acc:31.782943725585938
epoch:20,train_loss:2.367776843217703, train_acc:32.33481216430664, val_loss:2.39947415507117, val_acc:31.31056022644043
epoch:21,train_loss:2.357298065219405, train_acc:31.921842575073242, val_loss:2.3900515478710798, val_acc:30.95930290222168
epoch:22,train_loss:2.333338290276612, train_acc:32.633750915527344, val_loss:2.3940046221710913, val_acc:32.30377960205078
epoch:23,train_loss:2.3160707689601288, train_acc:33.906558990478516, val_loss:2.393957775692607, val_acc:31.649709701538086
epoch:24,train_loss:2.3079099768012235, train_acc:33.56755447387695, val_loss:2.34613381984622, val_acc:32.63081359863281
epoch:25,train_loss:2.296607259462571, train_acc:33.58296585083008, val_loss:2.3192355078320173, val_acc:33.34544372558594
epoch:26,train_loss:2.2738238241545545, train_acc:34.26097106933594, val_loss:2.332012026808983, val_acc:33.406009674072266
epoch:27,train_loss:2.253186898824026, train_acc:35.20709991455078, val_loss:2.315386486607929, val_acc:33.987403869628906
epoch:28,train_loss:2.250527298662084, train_acc:35.1886100769043, val_loss:2.335046102834302, val_acc:32.945735931396484
epoch:29,train_loss:2.2227693415252414, train_acc:35.265655517578125, val_loss:2.288711567257726, val_acc:33.90261459350586
epoch:30,train_loss:2.220126310749167, train_acc:35.99913787841797, val_loss:2.3039729567461236, val_acc:34.56879806518555
epoch:31,train_loss:2.2153540913169905, train_acc:36.15323257446289, val_loss:2.293829879095388, val_acc:34.68992233276367
epoch:32,train_loss:2.1947776474190888, train_acc:36.24260330200195, val_loss:2.2992831524028334, val_acc:34.02374267578125
epoch:33,train_loss:2.177062145351658, train_acc:36.9760856628418, val_loss:2.268732553304628, val_acc:36.36143112182617
epoch:34,train_loss:2.1558406670418013, train_acc:37.090110778808594, val_loss:2.2572834325391193, val_acc:35.86482620239258
epoch:35,train_loss:2.1473963056090315, train_acc:37.60786437988281, val_loss:2.2591070408044858, val_acc:36.19186019897461
epoch:36,train_loss:2.146941250597937, train_acc:37.04080581665039, val_loss:2.2474963692731635, val_acc:35.99806213378906
epoch:37,train_loss:2.1229546514488535, train_acc:38.76972579956055, val_loss:2.2526982856351276, val_acc:36.143409729003906
epoch:38,train_loss:2.1176861496366692, train_acc:38.396820068359375, val_loss:2.22594033008398, val_acc:36.66424560546875
epoch:39,train_loss:2.1076269269694943, train_acc:38.88683319091797, val_loss:2.2588198295859403, val_acc:36.5794563293457
epoch:40,train_loss:2.108998182962632, train_acc:38.71425247192383, val_loss:2.2012144382609877, val_acc:37.294090270996094
epoch:41,train_loss:2.0898374890434672, train_acc:39.087154388427734, val_loss:2.1979586246401763, val_acc:36.48255920410156
epoch:42,train_loss:2.0901577211696014, train_acc:38.57864761352539, val_loss:2.2086341076119003, val_acc:36.02228927612305
epoch:43,train_loss:2.0680763940133993, train_acc:39.318294525146484, val_loss:2.191249215325644, val_acc:35.828487396240234
epoch:44,train_loss:2.054594759405012, train_acc:39.712772369384766, val_loss:2.1973287033480267, val_acc:36.5794563293457
epoch:45,train_loss:2.039196566011779, train_acc:40.13806915283203, val_loss:2.1795612019161847, val_acc:37.28197479248047
epoch:46,train_loss:2.0301347343173957, train_acc:40.52946090698242, val_loss:2.2055412974468496, val_acc:35.804264068603516
epoch:47,train_loss:2.014961471924415, train_acc:40.94859313964844, val_loss:2.1725045913873715, val_acc:38.674903869628906
epoch:48,train_loss:2.0152080729162907, train_acc:41.60503005981445, val_loss:2.1827113212541094, val_acc:38.468990325927734
epoch:49,train_loss:2.003793822237726, train_acc:40.79450225830078, val_loss:2.159993235455003, val_acc:39.4985466003418
epoch:50,train_loss:1.9894032076265684, train_acc:42.11045455932617, val_loss:2.171448075494101, val_acc:37.14874267578125
epoch:51,train_loss:1.970363364417172, train_acc:42.51109313964844, val_loss:2.138068362723949, val_acc:39.7044563293457
epoch:52,train_loss:1.9762740981649365, train_acc:41.777610778808594, val_loss:2.143760695013889, val_acc:38.45688247680664
epoch:53,train_loss:1.9773973748528746, train_acc:41.722137451171875, val_loss:2.1409021532812784, val_acc:38.941375732421875
epoch:54,train_loss:1.9595034891331689, train_acc:42.05806350708008, val_loss:2.1526396108228107, val_acc:38.48110580444336
epoch:55,train_loss:1.942931172410412, train_acc:43.14287567138672, val_loss:2.129172862962235, val_acc:39.60755920410156
epoch:56,train_loss:1.9418711246118037, train_acc:42.803871154785156, val_loss:2.1091779359551364, val_acc:40.007266998291016
epoch:57,train_loss:1.9373988230552899, train_acc:42.785377502441406, val_loss:2.13331177345542, val_acc:38.32364273071289
epoch:58,train_loss:1.9329443831415571, train_acc:43.19834899902344, val_loss:2.1204048672387765, val_acc:39.83769607543945
epoch:59,train_loss:1.918922458174666, train_acc:43.85786437988281, val_loss:2.1178600788116455, val_acc:39.03827667236328
epoch:60,train_loss:1.8999799595782036, train_acc:44.09208679199219, val_loss:2.1590126863745756, val_acc:38.32364273071289
epoch:61,train_loss:1.8980442722873574, train_acc:44.10441589355469, val_loss:2.1230733616407527, val_acc:39.38953399658203
epoch:62,train_loss:1.901219623328666, train_acc:43.84553909301758, val_loss:2.1353326814119207, val_acc:39.171512603759766
epoch:63,train_loss:1.8761565579465156, train_acc:44.77317428588867, val_loss:2.1243791691092557, val_acc:39.1109504699707
epoch:64,train_loss:1.8660558187044585, train_acc:45.35564422607422, val_loss:2.147592017816943, val_acc:38.31153106689453
epoch:65,train_loss:1.8696399419265386, train_acc:45.130672454833984, val_loss:2.092461014902869, val_acc:40.176841735839844
epoch:66,train_loss:1.8564818935281426, train_acc:44.95808410644531, val_loss:2.090336184168971, val_acc:40.104164123535156
epoch:67,train_loss:1.8503782375324407, train_acc:45.485084533691406, val_loss:2.0790175005447034, val_acc:40.31007766723633
epoch:68,train_loss:1.8471501198040663, train_acc:45.42344665527344, val_loss:2.120376140572304, val_acc:40.5765495300293
epoch:69,train_loss:1.8262064047819058, train_acc:46.02440643310547, val_loss:2.0704417395037273, val_acc:40.72189712524414
epoch:70,train_loss:1.8295576847516573, train_acc:46.14459991455078, val_loss:2.0638344759164853, val_acc:41.15794372558594
epoch:71,train_loss:1.8096729601628683, train_acc:46.11994552612305, val_loss:2.0690932162972384, val_acc:40.89147186279297
epoch:72,train_loss:1.8182336578707723, train_acc:46.54524230957031, val_loss:2.0709913081901017, val_acc:40.87936019897461
epoch:73,train_loss:1.8036771645912757, train_acc:46.465110778808594, val_loss:2.1412681368894355, val_acc:39.401649475097656
epoch:74,train_loss:1.7917747934894448, train_acc:46.28944778442383, val_loss:2.067728890929111, val_acc:40.5765495300293
epoch:75,train_loss:1.7835359326481113, train_acc:47.377342224121094, val_loss:2.106135518051857, val_acc:39.83769607543945
epoch:76,train_loss:1.786717737214805, train_acc:46.85959243774414, val_loss:2.0853390056033465, val_acc:41.218509674072266
epoch:77,train_loss:1.7747366879818707, train_acc:47.759490966796875, val_loss:2.0839045491329458, val_acc:41.145835876464844
epoch:78,train_loss:1.7654184201765342, train_acc:47.54376220703125, val_loss:2.0704863293226374, val_acc:41.497093200683594
epoch:79,train_loss:1.7546812049030551, train_acc:47.61464309692383, val_loss:2.071953992510951, val_acc:41.315406799316406
epoch:80,train_loss:1.7469937568585547, train_acc:47.716346740722656, val_loss:2.0678959364114804, val_acc:42.42974853515625
epoch:81,train_loss:1.7372189950660841, train_acc:48.283409118652344, val_loss:2.0594216058420582, val_acc:42.27228927612305
epoch:82,train_loss:1.7496934127525465, train_acc:48.2125244140625, val_loss:2.0404803475668265, val_acc:42.75678253173828
epoch:83,train_loss:1.7196281957908495, train_acc:48.34812545776367, val_loss:2.0440582896387856, val_acc:41.37596893310547
epoch:84,train_loss:1.720940080620128, train_acc:48.579261779785156, val_loss:2.083317277043365, val_acc:39.84980392456055
epoch:85,train_loss:1.7203708492087189, train_acc:48.83197784423828, val_loss:2.060117527495983, val_acc:42.393409729003906
epoch:86,train_loss:1.7112735199505056, train_acc:48.94292449951172, val_loss:2.0421820490859277, val_acc:42.09060287475586
epoch:87,train_loss:1.6973145960350713, train_acc:49.46067810058594, val_loss:2.0402699514876965, val_acc:42.01792526245117
epoch:88,train_loss:1.7020337757979624, train_acc:49.44218444824219, val_loss:2.08253134960352, val_acc:41.0610466003418
epoch:89,train_loss:1.6742494550682383, train_acc:50.42221450805664, val_loss:2.0706246176431344, val_acc:41.64244079589844
epoch:90,train_loss:1.676584166182569, train_acc:49.953773498535156, val_loss:2.073023798853852, val_acc:41.17005920410156
epoch:91,train_loss:1.6772573897119105, train_acc:50.32051467895508, val_loss:2.0418946632119113, val_acc:41.88468933105469
epoch:92,train_loss:1.6651838601693598, train_acc:50.431461334228516, val_loss:2.0438435743021413, val_acc:41.46075439453125
epoch:93,train_loss:1.6536222773896168, train_acc:50.400638580322266, val_loss:2.0570213101630985, val_acc:41.37596893310547
epoch:94,train_loss:1.6471364462869407, train_acc:50.714988708496094, val_loss:2.053965829139532, val_acc:41.04893112182617
epoch:95,train_loss:1.631173533800791, train_acc:51.46696472167969, val_loss:2.039872352467027, val_acc:42.24806213378906
epoch:96,train_loss:1.633132583996248, train_acc:50.59479522705078, val_loss:2.0786331587059554, val_acc:40.21317672729492
epoch:97,train_loss:1.623470737383916, train_acc:51.40532684326172, val_loss:2.0779663573863894, val_acc:41.35174560546875
epoch:98,train_loss:1.6344603582246768, train_acc:51.14645004272461, val_loss:2.018333870311116, val_acc:43.82267379760742
epoch:99,train_loss:1.6260780554551344, train_acc:51.4484748840332, val_loss:2.0569261035253836, val_acc:41.254844665527344
epoch:100,train_loss:1.6022041372998932, train_acc:51.809051513671875, val_loss:2.0680011427679728, val_acc:41.96947479248047
In [36]:
# Your graph
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('Training and validation loss and accuracy at different learning rates')

# skip epoch 1 for lr=0.1 so its large initial loss does not dominate the y-axis
axs[0].plot(x_axis[1:],train_loss_lr_01[1:],label='train_loss at 0.1 learning rate')
axs[0].plot(x_axis,validate_loss_lr_01,label='validate_loss at 0.1 learning rate')
axs[0].plot(x_axis,train_loss_lr_0001,label='train_loss at 0.001 learning rate')
axs[0].plot(x_axis,validate_loss_lr_0001,label='validate_loss at 0.001 learning rate')
axs[0].plot(x_axis,train_loss_lr_00001,label='train_loss at 0.0001 learning rate')
axs[0].plot(x_axis,validate_loss_lr_00001,label='validate_loss at 0.0001 learning rate')

axs[1].plot(x_axis,train_acc_lr_01,label='train_accuracy at 0.1 learning rate')
axs[1].plot(x_axis,validate_acc_lr_01,label='validate_accuracy at 0.1 learning rate')
axs[1].plot(x_axis,train_acc_lr_0001,label='train_accuracy at 0.001 learning rate')
axs[1].plot(x_axis,validate_acc_lr_0001,label='validate_accuracy at 0.001 learning rate')
axs[1].plot(x_axis,train_acc_lr_00001,label='train_accuracy at 0.0001 learning rate')
axs[1].plot(x_axis,validate_acc_lr_00001,label='validate_accuracy at 0.0001 learning rate')

axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()

Comparing the three learning rates (0.1, 0.001, and 0.0001), the figure shows that with a learning rate of 0.001 the loss converges best and the highest accuracy is reached, so we select CNN_model_Improved_0001 (learning rate 0.001) as the final model.
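The comparison above can also be done programmatically. A minimal sketch with a hypothetical helper (`best_learning_rate` and the example accuracies below are illustrative, not the notebook's actual numbers):

```python
# Hypothetical helper: choose the learning rate whose run reached the
# highest validation accuracy at any epoch.
def best_learning_rate(runs):
    """runs maps a learning rate to its per-epoch validation accuracies."""
    return max(runs, key=lambda lr: max(runs[lr]))

# illustrative accuracy curves for three learning rates
runs = {
    0.1: [3.3, 3.4, 3.5],
    0.001: [12.0, 41.5, 43.8],
    0.0001: [12.1, 40.0, 42.5],
}
print(best_learning_rate(runs))  # -> 0.001
```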

3 Model testing¶

Online evaluation of your model performance on the test set.

Prepare the dataloader for test set

Write evaluation code that generates and saves predictions

Upload it to Kaggle submission page link

3.1 Test class and predictions¶

Build a test class, prepare a test dataloader and generate predictions

Create a PyTorch Dataset for the unlabeled test data in the test_set folder of the Kaggle competition and generate predictions using your final model. Test data can be downloaded here

In [34]:
# Your code here!
test_set = MyDataset("test_set")
test_loader = DataLoader(
    test_set,
    batch_size = 64,
    shuffle = False)
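`MyDataset` is defined earlier in the notebook; for an unlabeled test folder it needs to return filenames rather than labels. A minimal sketch of such a Dataset, assuming a flat folder of image files (class and folder names here are illustrative):

```python
import glob
import os

from PIL import Image
from torch.utils.data import Dataset


class UnlabeledImageDataset(Dataset):
    """Unlabeled test images: each item is (image, filename)."""

    def __init__(self, root, transform=None):
        # sort so the ordering is deterministic across runs
        self.paths = sorted(glob.glob(os.path.join(root, "*")))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # return the filename with the image so each prediction can be
        # matched back to its image ID in the submission CSV
        return img, os.path.basename(self.paths[idx])
```

Because there are no labels, each item carries its filename instead; in practice you would pass `transforms.ToTensor()` (plus whatever normalisation was used in training) as `transform`.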

3.2 Prepare submission and upload to Kaggle¶

Save all test predictions to a CSV file and submit it to the private class Kaggle competition.

In [37]:
# Your code here! 
num_class = len(classes)
CNN_model_Improved_0001.load_state_dict(torch.load('./CNN_model_Improved_0001.pt'))
CNN_model_Improved_0001.eval()  # disable dropout for deterministic predictions
dic = {"Id":[],"Category":[]}
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        outputs = CNN_model_Improved_0001(images)
        labels = [l.replace("comp5625M_data_assessment_1\\test_set\\test_set\\","") for l in labels]
        pre_y = torch.max(outputs, dim=1)[1].cpu().numpy()
        dic["Id"].extend(labels)
        dic["Category"].extend(pre_y)
df = pd.DataFrame(dic)
df.to_csv("ml21zw.csv",index = False)

4 Model Fine-tuning/transfer learning on CIFAR10 dataset¶

Fine-tuning is a way of applying transfer learning: a model that has already been trained for one task is adjusted so that it performs a second, similar task. You can perform fine-tuning in the following ways:

  • Train the entire model from scratch (needs a large dataset and more computation)
  • Freeze the convolutional base and train only the last FC layers (suited to small datasets, lower computation)

Configuring your dataset

  • Download your dataset using torchvision.datasets.CIFAR10, explained here
  • Split the training dataset into training and validation sets, as above. Note that there are only 10 categories here
In [3]:
# Your code here! 
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

CIFAR10trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
length = len(CIFAR10trainset)
CIFAR10train_size, CIFAR10validate_size = int(0.8 * length), int(0.2 * length)
CIFAR10trainset, CIFAR10validateset = torch.utils.data.random_split(
    CIFAR10trainset, [CIFAR10train_size, CIFAR10validate_size],
    generator=torch.Generator().manual_seed(0))
print(len(CIFAR10trainset), len(CIFAR10validateset))

CIFAR10trainloader = torch.utils.data.DataLoader(CIFAR10trainset, batch_size=4,
                                                 shuffle=True, num_workers=2)
# validation data need not be shuffled
CIFAR10validateloader = torch.utils.data.DataLoader(CIFAR10validateset, batch_size=4,
                                                    shuffle=False, num_workers=2)
Files already downloaded and verified
40000 10000

Load pretrained AlexNet from PyTorch - use model copies to apply transfer learning in different configurations

In [16]:
# Your code here! 
import torchvision.models as models
alexnet = models.alexnet(pretrained=True)
num_fc = alexnet.classifier[6].in_features
alexnet.classifier[6] = torch.nn.Linear(in_features=num_fc, out_features=10)
alexnet = alexnet.to(device)
print(alexnet)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)

4.1 Apply transfer learning: initialise with pretrained model weights¶

Use pretrained weights from AlexNet only (on the right of figure) to initialise your model.

Figure: Two models are given here: LeNet and AlexNet for image classification. However, you have to use **only AlexNet**.

Configuration 1: No frozen layers

In [22]:
# Your model changes here - also print trainable parameters
total_params = sum(p.numel() for p in alexnet.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in alexnet.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} trainable parameters.')
nepochs=100
optimizer = torch.optim.Adam(alexnet.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
alexnet_best_loss = math.inf
alexnet_train_loss, alexnet_validate_loss, alexnet_train_accuracy, alexnet_validate_accuracy = [], [], [], []
for epoch in range(nepochs):
    alexnet_train_running_loss , alexnet_train_running_accuracy = train(CIFAR10trainloader, alexnet, criterion, optimizer)
    alexnet_train_loss.append(alexnet_train_running_loss)
    alexnet_train_accuracy.append(alexnet_train_running_accuracy)
    alexnet_validate_running_loss , alexnet_validate_running_accuracy = validate(CIFAR10validateloader, alexnet, criterion, optimizer)
    alexnet_validate_loss.append(alexnet_validate_running_loss)
    alexnet_validate_accuracy.append(alexnet_validate_running_accuracy)
    if alexnet_validate_running_loss < alexnet_best_loss:
        alexnet_best_loss = alexnet_validate_running_loss
        torch.save(alexnet.state_dict(), './alexnet.pt')
    print(f"epoch: {epoch+1} alexnet_train_loss: {alexnet_train_loss[epoch] : .3f} alexnet_train_accuracy: {alexnet_train_accuracy[epoch] : .3f} alexnet_validate_loss: {alexnet_validate_loss[epoch] : .3f} alexnet_validate_accuracy: {alexnet_validate_accuracy[epoch] : .3f}")  
57,044,810 total parameters.
57,044,810 trainable parameters.
epoch: 1 alexnet_train_loss:  0.394 alexnet_train_accuracy:  87.548 alexnet_validate_loss:  0.442 alexnet_validate_accuracy:  86.210
epoch: 2 alexnet_train_loss:  0.333 alexnet_train_accuracy:  89.198 alexnet_validate_loss:  0.464 alexnet_validate_accuracy:  85.180
epoch: 3 alexnet_train_loss:  0.308 alexnet_train_accuracy:  90.020 alexnet_validate_loss:  0.494 alexnet_validate_accuracy:  86.010
epoch: 4 alexnet_train_loss:  0.271 alexnet_train_accuracy:  91.365 alexnet_validate_loss:  0.498 alexnet_validate_accuracy:  84.900
epoch: 5 alexnet_train_loss:  0.255 alexnet_train_accuracy:  91.923 alexnet_validate_loss:  0.499 alexnet_validate_accuracy:  85.640
epoch: 6 alexnet_train_loss:  0.246 alexnet_train_accuracy:  92.272 alexnet_validate_loss:  0.531 alexnet_validate_accuracy:  85.600
epoch: 7 alexnet_train_loss:  0.238 alexnet_train_accuracy:  92.600 alexnet_validate_loss:  0.486 alexnet_validate_accuracy:  85.370
epoch: 8 alexnet_train_loss:  0.223 alexnet_train_accuracy:  93.005 alexnet_validate_loss:  0.592 alexnet_validate_accuracy:  86.620
epoch: 9 alexnet_train_loss:  0.215 alexnet_train_accuracy:  93.397 alexnet_validate_loss:  0.650 alexnet_validate_accuracy:  84.940
epoch: 10 alexnet_train_loss:  0.220 alexnet_train_accuracy:  93.335 alexnet_validate_loss:  0.453 alexnet_validate_accuracy:  86.120
epoch: 11 alexnet_train_loss:  0.211 alexnet_train_accuracy:  93.745 alexnet_validate_loss:  0.533 alexnet_validate_accuracy:  85.970
epoch: 12 alexnet_train_loss:  0.208 alexnet_train_accuracy:  93.832 alexnet_validate_loss:  0.462 alexnet_validate_accuracy:  86.240
epoch: 13 alexnet_train_loss:  0.217 alexnet_train_accuracy:  93.512 alexnet_validate_loss:  0.548 alexnet_validate_accuracy:  83.470
epoch: 14 alexnet_train_loss:  0.226 alexnet_train_accuracy:  93.488 alexnet_validate_loss:  0.528 alexnet_validate_accuracy:  83.600
epoch: 15 alexnet_train_loss:  0.228 alexnet_train_accuracy:  93.338 alexnet_validate_loss:  0.503 alexnet_validate_accuracy:  85.330
epoch: 16 alexnet_train_loss:  0.232 alexnet_train_accuracy:  93.255 alexnet_validate_loss:  0.514 alexnet_validate_accuracy:  84.990
epoch: 17 alexnet_train_loss:  0.201 alexnet_train_accuracy:  94.115 alexnet_validate_loss:  0.588 alexnet_validate_accuracy:  87.160
epoch: 18 alexnet_train_loss:  0.237 alexnet_train_accuracy:  93.277 alexnet_validate_loss:  0.554 alexnet_validate_accuracy:  85.870
epoch: 19 alexnet_train_loss:  0.220 alexnet_train_accuracy:  93.665 alexnet_validate_loss:  0.636 alexnet_validate_accuracy:  84.580
epoch: 20 alexnet_train_loss:  0.230 alexnet_train_accuracy:  93.592 alexnet_validate_loss:  0.625 alexnet_validate_accuracy:  85.690
epoch: 21 alexnet_train_loss:  0.223 alexnet_train_accuracy:  93.670 alexnet_validate_loss:  0.734 alexnet_validate_accuracy:  84.950
epoch: 22 alexnet_train_loss:  0.228 alexnet_train_accuracy:  93.442 alexnet_validate_loss:  0.587 alexnet_validate_accuracy:  86.120
epoch: 23 alexnet_train_loss:  0.236 alexnet_train_accuracy:  93.415 alexnet_validate_loss:  0.563 alexnet_validate_accuracy:  86.530
epoch: 24 alexnet_train_loss:  0.232 alexnet_train_accuracy:  93.308 alexnet_validate_loss:  0.659 alexnet_validate_accuracy:  85.700
epoch: 25 alexnet_train_loss:  0.309 alexnet_train_accuracy:  92.298 alexnet_validate_loss:  0.536 alexnet_validate_accuracy:  85.710
epoch: 26 alexnet_train_loss:  0.321 alexnet_train_accuracy:  91.825 alexnet_validate_loss:  0.566 alexnet_validate_accuracy:  84.500
epoch: 27 alexnet_train_loss:  0.272 alexnet_train_accuracy:  92.228 alexnet_validate_loss:  0.720 alexnet_validate_accuracy:  76.640
epoch: 28 alexnet_train_loss:  0.348 alexnet_train_accuracy:  90.325 alexnet_validate_loss:  0.638 alexnet_validate_accuracy:  84.810
epoch: 29 alexnet_train_loss:  0.282 alexnet_train_accuracy:  92.220 alexnet_validate_loss:  0.655 alexnet_validate_accuracy:  82.370
epoch: 30 alexnet_train_loss:  0.309 alexnet_train_accuracy:  91.463 alexnet_validate_loss:  0.658 alexnet_validate_accuracy:  83.710
epoch: 31 alexnet_train_loss:  0.280 alexnet_train_accuracy:  91.950 alexnet_validate_loss:  0.549 alexnet_validate_accuracy:  84.320
epoch: 32 alexnet_train_loss:  0.312 alexnet_train_accuracy:  91.923 alexnet_validate_loss:  0.698 alexnet_validate_accuracy:  78.330
epoch: 33 alexnet_train_loss:  0.315 alexnet_train_accuracy:  91.230 alexnet_validate_loss:  0.658 alexnet_validate_accuracy:  85.120
epoch: 34 alexnet_train_loss:  0.320 alexnet_train_accuracy:  91.110 alexnet_validate_loss:  0.799 alexnet_validate_accuracy:  74.940
epoch: 35 alexnet_train_loss:  0.340 alexnet_train_accuracy:  90.615 alexnet_validate_loss:  0.658 alexnet_validate_accuracy:  79.170
epoch: 36 alexnet_train_loss:  0.717 alexnet_train_accuracy:  87.872 alexnet_validate_loss:  0.600 alexnet_validate_accuracy:  84.700
epoch: 37 alexnet_train_loss:  0.397 alexnet_train_accuracy:  89.100 alexnet_validate_loss:  0.738 alexnet_validate_accuracy:  77.710
epoch: 38 alexnet_train_loss:  0.358 alexnet_train_accuracy:  90.077 alexnet_validate_loss:  0.858 alexnet_validate_accuracy:  73.090
epoch: 39 alexnet_train_loss:  0.488 alexnet_train_accuracy:  89.698 alexnet_validate_loss:  5.082 alexnet_validate_accuracy:  84.990
epoch: 40 alexnet_train_loss:  0.473 alexnet_train_accuracy:  86.815 alexnet_validate_loss:  0.636 alexnet_validate_accuracy:  83.720
epoch: 41 alexnet_train_loss:  0.417 alexnet_train_accuracy:  88.455 alexnet_validate_loss:  0.610 alexnet_validate_accuracy:  86.340
epoch: 42 alexnet_train_loss:  0.480 alexnet_train_accuracy:  86.640 alexnet_validate_loss:  0.767 alexnet_validate_accuracy:  79.750
epoch: 43 alexnet_train_loss:  0.414 alexnet_train_accuracy:  88.062 alexnet_validate_loss:  1.541 alexnet_validate_accuracy:  84.850
epoch: 44 alexnet_train_loss:  0.582 alexnet_train_accuracy:  87.033 alexnet_validate_loss:  0.626 alexnet_validate_accuracy:  82.700
epoch: 45 alexnet_train_loss:  0.734 alexnet_train_accuracy:  81.022 alexnet_validate_loss:  0.663 alexnet_validate_accuracy:  81.080
epoch: 46 alexnet_train_loss:  0.468 alexnet_train_accuracy:  87.082 alexnet_validate_loss:  0.882 alexnet_validate_accuracy:  72.010
epoch: 47 alexnet_train_loss:  0.507 alexnet_train_accuracy:  85.895 alexnet_validate_loss:  0.708 alexnet_validate_accuracy:  81.370
epoch: 48 alexnet_train_loss:  0.612 alexnet_train_accuracy:  82.643 alexnet_validate_loss:  0.641 alexnet_validate_accuracy:  80.320
epoch: 49 alexnet_train_loss:  0.674 alexnet_train_accuracy:  80.455 alexnet_validate_loss:  1.058 alexnet_validate_accuracy:  84.870
epoch: 50 alexnet_train_loss:  0.627 alexnet_train_accuracy:  82.395 alexnet_validate_loss:  0.913 alexnet_validate_accuracy:  73.170
epoch: 51 alexnet_train_loss:  0.686 alexnet_train_accuracy:  80.045 alexnet_validate_loss:  1.792 alexnet_validate_accuracy:  79.090
epoch: 52 alexnet_train_loss:  1.001 alexnet_train_accuracy:  76.137 alexnet_validate_loss:  0.776 alexnet_validate_accuracy:  77.090
epoch: 53 alexnet_train_loss:  0.715 alexnet_train_accuracy:  79.103 alexnet_validate_loss:  0.705 alexnet_validate_accuracy:  81.800
epoch: 54 alexnet_train_loss:  0.814 alexnet_train_accuracy:  79.415 alexnet_validate_loss:  0.665 alexnet_validate_accuracy:  80.450
epoch: 55 alexnet_train_loss:  0.749 alexnet_train_accuracy:  78.162 alexnet_validate_loss:  0.740 alexnet_validate_accuracy:  75.410
epoch: 56 alexnet_train_loss:  0.758 alexnet_train_accuracy:  78.040 alexnet_validate_loss:  0.735 alexnet_validate_accuracy:  82.220
epoch: 57 alexnet_train_loss:  0.822 alexnet_train_accuracy:  75.415 alexnet_validate_loss:  0.637 alexnet_validate_accuracy:  81.720
epoch: 58 alexnet_train_loss:  0.968 alexnet_train_accuracy:  70.317 alexnet_validate_loss:  0.664 alexnet_validate_accuracy:  79.950
epoch: 59 alexnet_train_loss:  1.215 alexnet_train_accuracy:  64.935 alexnet_validate_loss:  1.104 alexnet_validate_accuracy:  65.590
epoch: 60 alexnet_train_loss:  0.947 alexnet_train_accuracy:  71.247 alexnet_validate_loss:  3.346 alexnet_validate_accuracy:  69.950
epoch: 61 alexnet_train_loss:  0.882 alexnet_train_accuracy:  73.620 alexnet_validate_loss:  0.836 alexnet_validate_accuracy:  74.500
epoch: 62 alexnet_train_loss:  0.899 alexnet_train_accuracy:  72.747 alexnet_validate_loss:  2.586 alexnet_validate_accuracy:  63.560
epoch: 63 alexnet_train_loss:  1.309 alexnet_train_accuracy:  64.990 alexnet_validate_loss:  1.402 alexnet_validate_accuracy:  51.000
epoch: 64 alexnet_train_loss:  2.877 alexnet_train_accuracy:  55.222 alexnet_validate_loss:  0.876 alexnet_validate_accuracy:  72.660
epoch: 65 alexnet_train_loss:  1.080 alexnet_train_accuracy:  67.098 alexnet_validate_loss:  1.120 alexnet_validate_accuracy:  62.940
epoch: 66 alexnet_train_loss:  1.510 alexnet_train_accuracy:  53.173 alexnet_validate_loss:  1.054 alexnet_validate_accuracy:  64.190
epoch: 67 alexnet_train_loss:  1.904 alexnet_train_accuracy:  36.960 alexnet_validate_loss:  1.695 alexnet_validate_accuracy:  34.830
epoch: 68 alexnet_train_loss:  1.859 alexnet_train_accuracy:  35.350 alexnet_validate_loss:  1.632 alexnet_validate_accuracy:  42.080
epoch: 69 alexnet_train_loss:  2.115 alexnet_train_accuracy:  23.198 alexnet_validate_loss:  1.932 alexnet_validate_accuracy:  18.550
epoch: 70 alexnet_train_loss:  2.123 alexnet_train_accuracy:  20.605 alexnet_validate_loss:  1.955 alexnet_validate_accuracy:  18.630
epoch: 71 alexnet_train_loss:  2.148 alexnet_train_accuracy:  19.675 alexnet_validate_loss:  2.004 alexnet_validate_accuracy:  19.760
epoch: 72 alexnet_train_loss:  1.951 alexnet_train_accuracy:  20.243 alexnet_validate_loss:  1.812 alexnet_validate_accuracy:  25.590
epoch: 73 alexnet_train_loss:  2.050 alexnet_train_accuracy:  19.660 alexnet_validate_loss:  1.994 alexnet_validate_accuracy:  20.080
epoch: 74 alexnet_train_loss:  1.950 alexnet_train_accuracy:  19.325 alexnet_validate_loss:  1.802 alexnet_validate_accuracy:  23.780
epoch: 75 alexnet_train_loss:  1.971 alexnet_train_accuracy:  20.507 alexnet_validate_loss:  1.786 alexnet_validate_accuracy:  23.440
epoch: 76 alexnet_train_loss:  1.848 alexnet_train_accuracy:  22.903 alexnet_validate_loss:  1.746 alexnet_validate_accuracy:  28.540
epoch: 77 alexnet_train_loss:  1.879 alexnet_train_accuracy:  22.945 alexnet_validate_loss:  1.878 alexnet_validate_accuracy:  20.150
epoch: 78 alexnet_train_loss:  2.378 alexnet_train_accuracy:  23.980 alexnet_validate_loss:  1.720 alexnet_validate_accuracy:  31.060
epoch: 79 alexnet_train_loss:  2.087 alexnet_train_accuracy:  25.340 alexnet_validate_loss:  1.715 alexnet_validate_accuracy:  28.690
epoch: 80 alexnet_train_loss:  1.823 alexnet_train_accuracy:  25.550 alexnet_validate_loss:  1.717 alexnet_validate_accuracy:  29.760
epoch: 81 alexnet_train_loss:  1.823 alexnet_train_accuracy:  25.700 alexnet_validate_loss:  1.670 alexnet_validate_accuracy:  30.970
epoch: 82 alexnet_train_loss:  3.502 alexnet_train_accuracy:  27.030 alexnet_validate_loss:  8.780 alexnet_validate_accuracy:  26.950
epoch: 83 alexnet_train_loss:  1.810 alexnet_train_accuracy:  26.185 alexnet_validate_loss:  2.207 alexnet_validate_accuracy:  29.380
epoch: 84 alexnet_train_loss:  3.035 alexnet_train_accuracy:  28.243 alexnet_validate_loss:  1.696 alexnet_validate_accuracy:  25.830
epoch: 85 alexnet_train_loss:  1.731 alexnet_train_accuracy:  29.350 alexnet_validate_loss:  1.970 alexnet_validate_accuracy:  21.500
epoch: 86 alexnet_train_loss:  1.741 alexnet_train_accuracy:  28.772 alexnet_validate_loss:  1.680 alexnet_validate_accuracy:  28.270
epoch: 87 alexnet_train_loss:  1.708 alexnet_train_accuracy:  30.250 alexnet_validate_loss:  1.668 alexnet_validate_accuracy:  31.590
epoch: 88 alexnet_train_loss:  1.743 alexnet_train_accuracy:  30.765 alexnet_validate_loss:  1.650 alexnet_validate_accuracy:  32.550
epoch: 89 alexnet_train_loss:  1.888 alexnet_train_accuracy:  31.335 alexnet_validate_loss:  2.090 alexnet_validate_accuracy:  18.980
epoch: 90 alexnet_train_loss:  1.654 alexnet_train_accuracy:  32.897 alexnet_validate_loss:  1.596 alexnet_validate_accuracy:  34.870
epoch: 91 alexnet_train_loss:  1.676 alexnet_train_accuracy:  33.213 alexnet_validate_loss:  1.739 alexnet_validate_accuracy:  25.260
epoch: 92 alexnet_train_loss:  1.714 alexnet_train_accuracy:  31.305 alexnet_validate_loss:  1.637 alexnet_validate_accuracy:  30.600
epoch: 93 alexnet_train_loss:  2.160 alexnet_train_accuracy:  28.192 alexnet_validate_loss:  1.606 alexnet_validate_accuracy:  35.830
epoch: 94 alexnet_train_loss:  1.741 alexnet_train_accuracy:  29.843 alexnet_validate_loss:  1.731 alexnet_validate_accuracy:  29.230
epoch: 95 alexnet_train_loss:  1.984 alexnet_train_accuracy:  31.427 alexnet_validate_loss:  1.626 alexnet_validate_accuracy:  34.230
epoch: 96 alexnet_train_loss:  2.703 alexnet_train_accuracy:  30.188 alexnet_validate_loss:  1.634 alexnet_validate_accuracy:  34.240
epoch: 97 alexnet_train_loss:  1.707 alexnet_train_accuracy:  31.753 alexnet_validate_loss:  1.677 alexnet_validate_accuracy:  29.910
epoch: 98 alexnet_train_loss:  1.685 alexnet_train_accuracy:  32.290 alexnet_validate_loss:  1.681 alexnet_validate_accuracy:  30.040
epoch: 99 alexnet_train_loss:  1.785 alexnet_train_accuracy:  29.317 alexnet_validate_loss:  1.579 alexnet_validate_accuracy:  33.700
epoch: 100 alexnet_train_loss:  1.759 alexnet_train_accuracy:  30.528 alexnet_validate_loss:  1.675 alexnet_validate_accuracy:  31.120

4.2 Fine-tuning model with frozen layers¶

Configuration 2: Frozen base convolution blocks

In [23]:
# Your changes here - also print trainable parameters
frozen_alexnet = models.alexnet(pretrained=True)
num_fc = frozen_alexnet.classifier[6].in_features
frozen_alexnet.classifier[6] = nn.Linear(in_features=num_fc, out_features=10)
frozen_alexnet = frozen_alexnet.to(device)
# Freeze everything, then unfreeze only the new classification head
for param in frozen_alexnet.parameters():
    param.requires_grad = False
for param in frozen_alexnet.classifier[6].parameters():
    param.requires_grad = True
criterion = nn.CrossEntropyLoss()
# Give the optimiser only the trainable parameters (the frozen ones have no gradients)
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, frozen_alexnet.parameters()), lr=0.001)

total_params = sum(p.numel() for p in frozen_alexnet.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in frozen_alexnet.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} trainable parameters.')

frozen_alexnet_best_loss = 1000
frozen_alexnet_train_loss, frozen_alexnet_validate_loss, frozen_alexnet_train_accuracy, frozen_alexnet_validate_accuracy = [], [], [], []
nepochs = 100
for epoch in range(nepochs):
    frozen_alexnet_train_running_loss , frozen_alexnet_train_running_accuracy = train(CIFAR10trainloader, frozen_alexnet, criterion, optimizer)
    frozen_alexnet_train_loss.append(frozen_alexnet_train_running_loss)
    frozen_alexnet_train_accuracy.append(frozen_alexnet_train_running_accuracy)
    frozen_alexnet_validate_running_loss , frozen_alexnet_validate_running_accuracy = validate(CIFAR10validateloader, frozen_alexnet, criterion, optimizer)
    frozen_alexnet_validate_loss.append(frozen_alexnet_validate_running_loss)
    frozen_alexnet_validate_accuracy.append(frozen_alexnet_validate_running_accuracy)
    if frozen_alexnet_validate_running_loss < frozen_alexnet_best_loss:
        frozen_alexnet_best_loss = frozen_alexnet_validate_running_loss
        torch.save(frozen_alexnet.state_dict(), './frozen_alexnet.pt')
    print(f"epoch: {epoch+1} alexnet_train_loss: {frozen_alexnet_train_loss[epoch] : .3f} frozen_alexnet_train_accuracy: {frozen_alexnet_train_accuracy[epoch] : .3f} frozen_alexnet_validate_loss: {frozen_alexnet_validate_loss[epoch] : .3f} frozen_alexnet_validate_accuracy: {frozen_alexnet_validate_accuracy[epoch] : .3f}")  
57,044,810 total parameters.
40,970 trainable parameters.
epoch: 1 alexnet_train_loss:  1.221 frozen_alexnet_train_accuracy:  65.088 frozen_alexnet_validate_loss:  1.013 frozen_alexnet_validate_accuracy:  70.520
epoch: 2 alexnet_train_loss:  1.211 frozen_alexnet_train_accuracy:  68.207 frozen_alexnet_validate_loss:  0.986 frozen_alexnet_validate_accuracy:  72.110
epoch: 3 alexnet_train_loss:  1.225 frozen_alexnet_train_accuracy:  69.043 frozen_alexnet_validate_loss:  1.026 frozen_alexnet_validate_accuracy:  72.350
epoch: 4 alexnet_train_loss:  1.227 frozen_alexnet_train_accuracy:  69.412 frozen_alexnet_validate_loss:  1.055 frozen_alexnet_validate_accuracy:  71.560
epoch: 5 alexnet_train_loss:  1.233 frozen_alexnet_train_accuracy:  69.603 frozen_alexnet_validate_loss:  1.062 frozen_alexnet_validate_accuracy:  72.770
epoch: 6 alexnet_train_loss:  1.228 frozen_alexnet_train_accuracy:  70.062 frozen_alexnet_validate_loss:  1.000 frozen_alexnet_validate_accuracy:  74.160
epoch: 7 alexnet_train_loss:  1.235 frozen_alexnet_train_accuracy:  70.420 frozen_alexnet_validate_loss:  1.150 frozen_alexnet_validate_accuracy:  72.020
epoch: 8 alexnet_train_loss:  1.228 frozen_alexnet_train_accuracy:  70.717 frozen_alexnet_validate_loss:  1.062 frozen_alexnet_validate_accuracy:  73.700
epoch: 9 alexnet_train_loss:  1.255 frozen_alexnet_train_accuracy:  70.018 frozen_alexnet_validate_loss:  1.055 frozen_alexnet_validate_accuracy:  72.700
epoch: 10 alexnet_train_loss:  1.236 frozen_alexnet_train_accuracy:  70.935 frozen_alexnet_validate_loss:  1.181 frozen_alexnet_validate_accuracy:  72.630
epoch: 11 alexnet_train_loss:  1.272 frozen_alexnet_train_accuracy:  70.298 frozen_alexnet_validate_loss:  1.038 frozen_alexnet_validate_accuracy:  73.880
epoch: 12 alexnet_train_loss:  1.255 frozen_alexnet_train_accuracy:  70.700 frozen_alexnet_validate_loss:  1.059 frozen_alexnet_validate_accuracy:  73.430
epoch: 13 alexnet_train_loss:  1.253 frozen_alexnet_train_accuracy:  70.692 frozen_alexnet_validate_loss:  0.997 frozen_alexnet_validate_accuracy:  74.390
epoch: 14 alexnet_train_loss:  1.252 frozen_alexnet_train_accuracy:  71.062 frozen_alexnet_validate_loss:  0.960 frozen_alexnet_validate_accuracy:  75.530
epoch: 15 alexnet_train_loss:  1.261 frozen_alexnet_train_accuracy:  70.690 frozen_alexnet_validate_loss:  1.064 frozen_alexnet_validate_accuracy:  73.250
epoch: 16 alexnet_train_loss:  1.251 frozen_alexnet_train_accuracy:  71.050 frozen_alexnet_validate_loss:  1.081 frozen_alexnet_validate_accuracy:  73.700
epoch: 17 alexnet_train_loss:  1.249 frozen_alexnet_train_accuracy:  71.442 frozen_alexnet_validate_loss:  1.227 frozen_alexnet_validate_accuracy:  71.940
epoch: 18 alexnet_train_loss:  1.283 frozen_alexnet_train_accuracy:  70.652 frozen_alexnet_validate_loss:  0.970 frozen_alexnet_validate_accuracy:  75.880
epoch: 19 alexnet_train_loss:  1.272 frozen_alexnet_train_accuracy:  70.777 frozen_alexnet_validate_loss:  1.234 frozen_alexnet_validate_accuracy:  71.300
epoch: 20 alexnet_train_loss:  1.252 frozen_alexnet_train_accuracy:  70.905 frozen_alexnet_validate_loss:  1.022 frozen_alexnet_validate_accuracy:  74.850
epoch: 21 alexnet_train_loss:  1.263 frozen_alexnet_train_accuracy:  71.228 frozen_alexnet_validate_loss:  0.971 frozen_alexnet_validate_accuracy:  75.340
epoch: 22 alexnet_train_loss:  1.269 frozen_alexnet_train_accuracy:  71.353 frozen_alexnet_validate_loss:  1.084 frozen_alexnet_validate_accuracy:  73.930
epoch: 23 alexnet_train_loss:  1.263 frozen_alexnet_train_accuracy:  71.058 frozen_alexnet_validate_loss:  1.166 frozen_alexnet_validate_accuracy:  72.880
epoch: 24 alexnet_train_loss:  1.274 frozen_alexnet_train_accuracy:  71.213 frozen_alexnet_validate_loss:  0.962 frozen_alexnet_validate_accuracy:  75.790
epoch: 25 alexnet_train_loss:  1.268 frozen_alexnet_train_accuracy:  71.150 frozen_alexnet_validate_loss:  0.943 frozen_alexnet_validate_accuracy:  76.760
epoch: 26 alexnet_train_loss:  1.278 frozen_alexnet_train_accuracy:  71.027 frozen_alexnet_validate_loss:  1.158 frozen_alexnet_validate_accuracy:  72.650
epoch: 27 alexnet_train_loss:  1.272 frozen_alexnet_train_accuracy:  71.143 frozen_alexnet_validate_loss:  1.019 frozen_alexnet_validate_accuracy:  75.980
epoch: 28 alexnet_train_loss:  1.264 frozen_alexnet_train_accuracy:  71.363 frozen_alexnet_validate_loss:  1.013 frozen_alexnet_validate_accuracy:  75.070
epoch: 29 alexnet_train_loss:  1.276 frozen_alexnet_train_accuracy:  71.065 frozen_alexnet_validate_loss:  0.982 frozen_alexnet_validate_accuracy:  75.500
epoch: 30 alexnet_train_loss:  1.278 frozen_alexnet_train_accuracy:  71.158 frozen_alexnet_validate_loss:  1.135 frozen_alexnet_validate_accuracy:  73.710
epoch: 31 alexnet_train_loss:  1.285 frozen_alexnet_train_accuracy:  71.522 frozen_alexnet_validate_loss:  1.156 frozen_alexnet_validate_accuracy:  72.690
epoch: 32 alexnet_train_loss:  1.276 frozen_alexnet_train_accuracy:  71.247 frozen_alexnet_validate_loss:  1.176 frozen_alexnet_validate_accuracy:  73.220
epoch: 33 alexnet_train_loss:  1.283 frozen_alexnet_train_accuracy:  71.092 frozen_alexnet_validate_loss:  1.096 frozen_alexnet_validate_accuracy:  74.230
epoch: 34 alexnet_train_loss:  1.274 frozen_alexnet_train_accuracy:  71.092 frozen_alexnet_validate_loss:  0.950 frozen_alexnet_validate_accuracy:  75.980
epoch: 35 alexnet_train_loss:  1.284 frozen_alexnet_train_accuracy:  71.173 frozen_alexnet_validate_loss:  1.071 frozen_alexnet_validate_accuracy:  74.110
epoch: 36 alexnet_train_loss:  1.276 frozen_alexnet_train_accuracy:  71.402 frozen_alexnet_validate_loss:  1.087 frozen_alexnet_validate_accuracy:  74.430
epoch: 37 alexnet_train_loss:  1.270 frozen_alexnet_train_accuracy:  71.510 frozen_alexnet_validate_loss:  1.233 frozen_alexnet_validate_accuracy:  72.780
epoch: 38 alexnet_train_loss:  1.280 frozen_alexnet_train_accuracy:  71.380 frozen_alexnet_validate_loss:  0.918 frozen_alexnet_validate_accuracy:  76.990
epoch: 39 alexnet_train_loss:  1.274 frozen_alexnet_train_accuracy:  71.555 frozen_alexnet_validate_loss:  1.076 frozen_alexnet_validate_accuracy:  74.100
epoch: 40 alexnet_train_loss:  1.266 frozen_alexnet_train_accuracy:  71.777 frozen_alexnet_validate_loss:  1.021 frozen_alexnet_validate_accuracy:  75.480
epoch: 41 alexnet_train_loss:  1.289 frozen_alexnet_train_accuracy:  70.963 frozen_alexnet_validate_loss:  0.961 frozen_alexnet_validate_accuracy:  76.430
epoch: 42 alexnet_train_loss:  1.286 frozen_alexnet_train_accuracy:  71.382 frozen_alexnet_validate_loss:  1.079 frozen_alexnet_validate_accuracy:  74.010
epoch: 43 alexnet_train_loss:  1.273 frozen_alexnet_train_accuracy:  71.393 frozen_alexnet_validate_loss:  1.062 frozen_alexnet_validate_accuracy:  75.150
epoch: 44 alexnet_train_loss:  1.283 frozen_alexnet_train_accuracy:  71.365 frozen_alexnet_validate_loss:  0.995 frozen_alexnet_validate_accuracy:  75.750
epoch: 45 alexnet_train_loss:  1.270 frozen_alexnet_train_accuracy:  71.560 frozen_alexnet_validate_loss:  1.065 frozen_alexnet_validate_accuracy:  73.920
epoch: 46 alexnet_train_loss:  1.287 frozen_alexnet_train_accuracy:  71.393 frozen_alexnet_validate_loss:  1.020 frozen_alexnet_validate_accuracy:  75.250
epoch: 47 alexnet_train_loss:  1.287 frozen_alexnet_train_accuracy:  71.507 frozen_alexnet_validate_loss:  1.338 frozen_alexnet_validate_accuracy:  71.470
epoch: 48 alexnet_train_loss:  1.291 frozen_alexnet_train_accuracy:  71.433 frozen_alexnet_validate_loss:  0.999 frozen_alexnet_validate_accuracy:  75.370
epoch: 49 alexnet_train_loss:  1.277 frozen_alexnet_train_accuracy:  71.548 frozen_alexnet_validate_loss:  1.393 frozen_alexnet_validate_accuracy:  70.220
epoch: 50 alexnet_train_loss:  1.282 frozen_alexnet_train_accuracy:  71.548 frozen_alexnet_validate_loss:  0.990 frozen_alexnet_validate_accuracy:  75.620
epoch: 51 alexnet_train_loss:  1.288 frozen_alexnet_train_accuracy:  71.418 frozen_alexnet_validate_loss:  0.969 frozen_alexnet_validate_accuracy:  75.800
epoch: 52 alexnet_train_loss:  1.279 frozen_alexnet_train_accuracy:  71.338 frozen_alexnet_validate_loss:  0.987 frozen_alexnet_validate_accuracy:  76.050
epoch: 53 alexnet_train_loss:  1.289 frozen_alexnet_train_accuracy:  71.442 frozen_alexnet_validate_loss:  1.145 frozen_alexnet_validate_accuracy:  73.350
epoch: 54 alexnet_train_loss:  1.277 frozen_alexnet_train_accuracy:  71.610 frozen_alexnet_validate_loss:  1.040 frozen_alexnet_validate_accuracy:  74.970
epoch: 55 alexnet_train_loss:  1.273 frozen_alexnet_train_accuracy:  71.495 frozen_alexnet_validate_loss:  1.073 frozen_alexnet_validate_accuracy:  73.670
epoch: 56 alexnet_train_loss:  1.296 frozen_alexnet_train_accuracy:  71.390 frozen_alexnet_validate_loss:  1.147 frozen_alexnet_validate_accuracy:  73.320
epoch: 57 alexnet_train_loss:  1.287 frozen_alexnet_train_accuracy:  71.395 frozen_alexnet_validate_loss:  1.131 frozen_alexnet_validate_accuracy:  72.870
epoch: 58 alexnet_train_loss:  1.285 frozen_alexnet_train_accuracy:  71.520 frozen_alexnet_validate_loss:  1.129 frozen_alexnet_validate_accuracy:  73.050
epoch: 59 alexnet_train_loss:  1.295 frozen_alexnet_train_accuracy:  71.257 frozen_alexnet_validate_loss:  1.079 frozen_alexnet_validate_accuracy:  73.970
epoch: 60 alexnet_train_loss:  1.296 frozen_alexnet_train_accuracy:  71.430 frozen_alexnet_validate_loss:  1.073 frozen_alexnet_validate_accuracy:  74.150
epoch: 61 alexnet_train_loss:  1.285 frozen_alexnet_train_accuracy:  71.625 frozen_alexnet_validate_loss:  0.973 frozen_alexnet_validate_accuracy:  75.700
epoch: 62 alexnet_train_loss:  1.275 frozen_alexnet_train_accuracy:  71.452 frozen_alexnet_validate_loss:  1.129 frozen_alexnet_validate_accuracy:  73.560
epoch: 63 alexnet_train_loss:  1.294 frozen_alexnet_train_accuracy:  71.283 frozen_alexnet_validate_loss:  1.064 frozen_alexnet_validate_accuracy:  74.510
epoch: 64 alexnet_train_loss:  1.302 frozen_alexnet_train_accuracy:  71.350 frozen_alexnet_validate_loss:  1.124 frozen_alexnet_validate_accuracy:  72.980
epoch: 65 alexnet_train_loss:  1.287 frozen_alexnet_train_accuracy:  71.228 frozen_alexnet_validate_loss:  1.002 frozen_alexnet_validate_accuracy:  75.530
epoch: 66 alexnet_train_loss:  1.289 frozen_alexnet_train_accuracy:  71.675 frozen_alexnet_validate_loss:  1.023 frozen_alexnet_validate_accuracy:  75.140
epoch: 67 alexnet_train_loss:  1.302 frozen_alexnet_train_accuracy:  71.180 frozen_alexnet_validate_loss:  0.964 frozen_alexnet_validate_accuracy:  76.340
epoch: 68 alexnet_train_loss:  1.290 frozen_alexnet_train_accuracy:  71.615 frozen_alexnet_validate_loss:  1.032 frozen_alexnet_validate_accuracy:  74.880
epoch: 69 alexnet_train_loss:  1.299 frozen_alexnet_train_accuracy:  71.375 frozen_alexnet_validate_loss:  1.012 frozen_alexnet_validate_accuracy:  75.560
epoch: 70 alexnet_train_loss:  1.294 frozen_alexnet_train_accuracy:  71.412 frozen_alexnet_validate_loss:  0.999 frozen_alexnet_validate_accuracy:  76.100
epoch: 71 alexnet_train_loss:  1.299 frozen_alexnet_train_accuracy:  71.085 frozen_alexnet_validate_loss:  1.129 frozen_alexnet_validate_accuracy:  73.600
epoch: 72 alexnet_train_loss:  1.290 frozen_alexnet_train_accuracy:  71.332 frozen_alexnet_validate_loss:  1.012 frozen_alexnet_validate_accuracy:  74.650
epoch: 73 alexnet_train_loss:  1.286 frozen_alexnet_train_accuracy:  71.215 frozen_alexnet_validate_loss:  1.049 frozen_alexnet_validate_accuracy:  75.230
epoch: 74 alexnet_train_loss:  1.284 frozen_alexnet_train_accuracy:  71.660 frozen_alexnet_validate_loss:  0.970 frozen_alexnet_validate_accuracy:  75.570
epoch: 75 alexnet_train_loss:  1.280 frozen_alexnet_train_accuracy:  71.757 frozen_alexnet_validate_loss:  1.088 frozen_alexnet_validate_accuracy:  74.000
epoch: 76 alexnet_train_loss:  1.294 frozen_alexnet_train_accuracy:  71.357 frozen_alexnet_validate_loss:  1.112 frozen_alexnet_validate_accuracy:  73.980
epoch: 77 alexnet_train_loss:  1.297 frozen_alexnet_train_accuracy:  71.327 frozen_alexnet_validate_loss:  1.116 frozen_alexnet_validate_accuracy:  73.670
epoch: 78 alexnet_train_loss:  1.304 frozen_alexnet_train_accuracy:  71.510 frozen_alexnet_validate_loss:  1.069 frozen_alexnet_validate_accuracy:  74.740
epoch: 79 alexnet_train_loss:  1.295 frozen_alexnet_train_accuracy:  71.543 frozen_alexnet_validate_loss:  1.163 frozen_alexnet_validate_accuracy:  73.060
epoch: 80 alexnet_train_loss:  1.282 frozen_alexnet_train_accuracy:  71.673 frozen_alexnet_validate_loss:  1.058 frozen_alexnet_validate_accuracy:  75.470
epoch: 81 alexnet_train_loss:  1.303 frozen_alexnet_train_accuracy:  71.605 frozen_alexnet_validate_loss:  1.134 frozen_alexnet_validate_accuracy:  73.480
epoch: 82 alexnet_train_loss:  1.285 frozen_alexnet_train_accuracy:  71.640 frozen_alexnet_validate_loss:  1.027 frozen_alexnet_validate_accuracy:  75.220
epoch: 83 alexnet_train_loss:  1.295 frozen_alexnet_train_accuracy:  71.637 frozen_alexnet_validate_loss:  1.229 frozen_alexnet_validate_accuracy:  72.680
epoch: 84 alexnet_train_loss:  1.280 frozen_alexnet_train_accuracy:  71.610 frozen_alexnet_validate_loss:  1.176 frozen_alexnet_validate_accuracy:  73.700
epoch: 85 alexnet_train_loss:  1.297 frozen_alexnet_train_accuracy:  71.115 frozen_alexnet_validate_loss:  1.343 frozen_alexnet_validate_accuracy:  71.130
epoch: 86 alexnet_train_loss:  1.292 frozen_alexnet_train_accuracy:  71.495 frozen_alexnet_validate_loss:  1.181 frozen_alexnet_validate_accuracy:  72.380
epoch: 87 alexnet_train_loss:  1.278 frozen_alexnet_train_accuracy:  71.567 frozen_alexnet_validate_loss:  1.094 frozen_alexnet_validate_accuracy:  73.420
epoch: 88 alexnet_train_loss:  1.293 frozen_alexnet_train_accuracy:  71.357 frozen_alexnet_validate_loss:  1.094 frozen_alexnet_validate_accuracy:  74.610
epoch: 89 alexnet_train_loss:  1.308 frozen_alexnet_train_accuracy:  71.317 frozen_alexnet_validate_loss:  1.055 frozen_alexnet_validate_accuracy:  74.620
epoch: 90 alexnet_train_loss:  1.303 frozen_alexnet_train_accuracy:  71.408 frozen_alexnet_validate_loss:  1.019 frozen_alexnet_validate_accuracy:  76.000
epoch: 91 alexnet_train_loss:  1.286 frozen_alexnet_train_accuracy:  71.310 frozen_alexnet_validate_loss:  1.136 frozen_alexnet_validate_accuracy:  73.380
epoch: 92 alexnet_train_loss:  1.282 frozen_alexnet_train_accuracy:  71.622 frozen_alexnet_validate_loss:  1.216 frozen_alexnet_validate_accuracy:  72.250
epoch: 93 alexnet_train_loss:  1.294 frozen_alexnet_train_accuracy:  71.545 frozen_alexnet_validate_loss:  1.079 frozen_alexnet_validate_accuracy:  74.420
epoch: 94 alexnet_train_loss:  1.306 frozen_alexnet_train_accuracy:  71.435 frozen_alexnet_validate_loss:  0.997 frozen_alexnet_validate_accuracy:  75.490
epoch: 95 alexnet_train_loss:  1.282 frozen_alexnet_train_accuracy:  71.442 frozen_alexnet_validate_loss:  1.135 frozen_alexnet_validate_accuracy:  73.330
epoch: 96 alexnet_train_loss:  1.306 frozen_alexnet_train_accuracy:  71.442 frozen_alexnet_validate_loss:  1.148 frozen_alexnet_validate_accuracy:  72.790
epoch: 97 alexnet_train_loss:  1.292 frozen_alexnet_train_accuracy:  71.410 frozen_alexnet_validate_loss:  0.954 frozen_alexnet_validate_accuracy:  76.570
epoch: 98 alexnet_train_loss:  1.301 frozen_alexnet_train_accuracy:  71.357 frozen_alexnet_validate_loss:  1.255 frozen_alexnet_validate_accuracy:  71.920
epoch: 99 alexnet_train_loss:  1.271 frozen_alexnet_train_accuracy:  72.003 frozen_alexnet_validate_loss:  1.179 frozen_alexnet_validate_accuracy:  72.660
epoch: 100 alexnet_train_loss:  1.304 frozen_alexnet_train_accuracy:  71.537 frozen_alexnet_validate_loss:  1.168 frozen_alexnet_validate_accuracy:  72.030

4.3 Compare above configurations and comment on comparative performance¶

In [26]:
# Your graphs here and please provide comment in markdown in another cell
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,2,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('Loss and accuracy on the train and validation sets, with and without frozen convolutional layers')

axs[0][0].plot(x_axis,alexnet_train_loss,label='alexnet train_loss')
axs[0][0].plot(x_axis,frozen_alexnet_train_loss,label='frozen_alexnet train_loss')
axs[0][1].plot(x_axis,alexnet_validate_loss,label='alexnet validate_loss')
axs[0][1].plot(x_axis,frozen_alexnet_validate_loss,label='frozen_alexnet validate_loss')
axs[1][0].plot(x_axis,alexnet_train_accuracy,label='alexnet train_accuracy')
axs[1][0].plot(x_axis,frozen_alexnet_train_accuracy,label='frozen_alexnet train_accuracy')

axs[1][1].plot(x_axis,alexnet_validate_accuracy,label='alexnet validate_accuracy')
axs[1][1].plot(x_axis,frozen_alexnet_validate_accuracy,label='frozen_alexnet validate_accuracy')

axs[1][0].set_xlabel('epoch')
axs[1][1].set_xlabel('epoch')
axs[0][0].set_ylabel('loss')
axs[0][1].set_ylabel('loss')
axs[1][0].set_ylabel('percentage of accuracy')
axs[1][1].set_ylabel('percentage of accuracy')
axs[0][0].legend()
axs[0][1].legend()
axs[1][1].legend()
axs[1][0].legend()
plt.show()

When the convolutional layers are not frozen, the pretrained AlexNet weights are progressively overwritten as training proceeds. With these features destabilised, the model effectively has to relearn its representations, which explains the steep drop in accuracy and the sharp rise and fluctuation in loss. Freezing the convolutional base keeps the pretrained features intact, so training is far more stable.
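A middle ground between fully freezing and fully fine-tuning is to give the pretrained backbone a much smaller learning rate than the new head, so the pretrained weights change slowly instead of being overwritten. A minimal sketch of per-group learning rates in Adam (the layers and rates here are illustrative stand-ins, not the coursework's settings):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained backbone plus a freshly initialised head
model = nn.Sequential(
    nn.Linear(16, 32),  # "pretrained" backbone layer
    nn.ReLU(),
    nn.Linear(32, 10),  # new classification head
)

# Discriminative learning rates: small for the backbone, larger for the head
optimizer = torch.optim.Adam([
    {"params": model[0].parameters(), "lr": 1e-5},
    {"params": model[2].parameters(), "lr": 1e-3},
])

print([group["lr"] for group in optimizer.param_groups])
```

The same pattern applies to AlexNet by putting `features.parameters()` and `classifier[6].parameters()` in separate groups.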

5 Model comparisons¶

We often need to compare our model with other state-of-the-art methods to understand how well it performs against existing architectures. You will therefore compare your model design with AlexNet on the TinyImageNet30 dataset.

5.1 Finetune AlexNet on TinyImageNet30¶

Load AlexNet as you did above.

Train AlexNet on the TinyImageNet30 dataset until convergence. Make sure you use the same dataset.
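"Until convergence" can be made concrete with a simple early-stopping check on the validation loss. The helper below is a sketch; the `patience` and `min_delta` values are illustrative, not prescribed by the coursework:

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Example: the loss plateaus after the third epoch
stopper = EarlyStopping(patience=2)
losses = [1.0, 0.9, 0.85, 0.86, 0.87, 0.88]
stopped_at = next(i for i, loss in enumerate(losses) if stopper.step(loss))
print(stopped_at)  # stops at epoch index 4: two epochs without improvement
```

Inside the training loop, `if stopper.step(validate_running_loss): break` would end training once the validation loss stagnates.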

In [14]:
# Your code here! 
alexnet_compare = models.alexnet(pretrained=True)
com_num_fc = alexnet_compare.classifier[6].in_features
alexnet_compare.classifier[6] = torch.nn.Linear(in_features=com_num_fc, out_features=30)
alexnet_compare = alexnet_compare.to(device)
print(alexnet_compare)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=30, bias=True)
  )
)
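The parameter count of the replaced head is easy to sanity-check by hand: a `Linear(4096, 30)` layer holds 4096 × 30 weights plus 30 biases. A quick standalone check (independent of the notebook's model):

```python
import torch.nn as nn

# Same shape as the replacement classifier[6] above
head = nn.Linear(in_features=4096, out_features=30)
n_params = sum(p.numel() for p in head.parameters())
print(n_params)  # 4096 * 30 + 30 = 122,910
```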
In [15]:
data_augmentation_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                    transforms.RandomRotation((-20,20)),
                    transforms.ColorJitter(hue=0.2, saturation=0.2, brightness=0.2),
                    transforms.Resize(224),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                    ])

train_set = MyDataset("train_set",transform=data_augmentation_transform)
length=len(train_set)
train_size = int(0.8*length)
validate_size = length - train_size  # ensures the split sizes sum to the dataset length
train_set,validate_set=torch.utils.data.random_split(train_set,[train_size,validate_size],generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(
    train_set,
    batch_size = 64,
    shuffle = True)
validate_loader = DataLoader(
    validate_set,
    batch_size = 64,
    shuffle = True)
In [17]:
# Your code here! 
alexnet_start = time.time()
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(alexnet_compare.parameters(), lr=0.0005)

alexnet_compare_best_loss = 1000.0
alexnet_compare_train_loss, alexnet_compare_train_accuracy= [], []
alexnet_compare_validate_loss, alexnet_compare_validate_accuracy= [], []

for epoch in range(nepochs):
    alexnet_compare_train_running_loss , alexnet_compare_train_running_accuracy = train(train_loader, alexnet_compare, criterion, optimizer)
    alexnet_compare_train_loss.append(alexnet_compare_train_running_loss)
    alexnet_compare_train_accuracy.append(alexnet_compare_train_running_accuracy)
    alexnet_compare_validate_running_loss , alexnet_compare_validate_running_accuracy = validate(validate_loader, alexnet_compare, criterion, optimizer)
    alexnet_compare_validate_loss.append(alexnet_compare_validate_running_loss)
    alexnet_compare_validate_accuracy.append(alexnet_compare_validate_running_accuracy)
    if alexnet_compare_validate_running_loss < alexnet_compare_best_loss:
        alexnet_compare_best_loss = alexnet_compare_validate_running_loss
        torch.save(alexnet_compare.state_dict(), './alexnet_compare.pt')
    print(f"epoch: {epoch+1} train_loss: {alexnet_compare_train_running_loss : .3f} train_accuracy: {alexnet_compare_train_running_accuracy : .3f} validate_loss: {alexnet_compare_validate_running_loss : .3f} validate_accuracy: {alexnet_compare_validate_running_accuracy : .3f}")
    
alexnet_end = time.time()
alexnet_running_time = alexnet_end - alexnet_start
epoch: 1 train_loss:  0.998 train_accuracy:  70.368 validate_loss:  1.710 validate_accuracy:  55.136
epoch: 2 train_loss:  0.993 train_accuracy:  70.602 validate_loss:  1.614 validate_accuracy:  57.679
epoch: 3 train_loss:  0.963 train_accuracy:  71.240 validate_loss:  1.662 validate_accuracy:  56.020
epoch: 4 train_loss:  0.999 train_accuracy:  70.648 validate_loss:  1.652 validate_accuracy:  56.565
epoch: 5 train_loss:  1.011 train_accuracy:  70.115 validate_loss:  1.563 validate_accuracy:  56.904
epoch: 6 train_loss:  0.963 train_accuracy:  71.320 validate_loss:  1.634 validate_accuracy:  56.880
epoch: 7 train_loss:  0.958 train_accuracy:  71.083 validate_loss:  1.595 validate_accuracy:  56.795
epoch: 8 train_loss:  0.950 train_accuracy:  71.893 validate_loss:  1.600 validate_accuracy:  56.831
epoch: 9 train_loss:  0.943 train_accuracy:  71.650 validate_loss:  1.632 validate_accuracy:  57.376
epoch: 10 train_loss:  0.946 train_accuracy:  71.946 validate_loss:  1.660 validate_accuracy:  55.729
epoch: 11 train_loss:  0.940 train_accuracy:  71.875 validate_loss:  1.634 validate_accuracy:  55.984
epoch: 12 train_loss:  0.922 train_accuracy:  72.482 validate_loss:  1.561 validate_accuracy:  57.219
epoch: 13 train_loss:  0.940 train_accuracy:  71.761 validate_loss:  1.674 validate_accuracy:  55.293
epoch: 14 train_loss:  0.961 train_accuracy:  71.320 validate_loss:  1.668 validate_accuracy:  55.608
epoch: 15 train_loss:  0.968 train_accuracy:  71.003 validate_loss:  1.598 validate_accuracy:  56.468
epoch: 16 train_loss:  0.919 train_accuracy:  71.912 validate_loss:  1.666 validate_accuracy:  55.608
epoch: 17 train_loss:  0.950 train_accuracy:  72.149 validate_loss:  1.644 validate_accuracy:  54.603
epoch: 18 train_loss:  0.970 train_accuracy:  71.071 validate_loss:  1.620 validate_accuracy:  56.953
epoch: 19 train_loss:  0.900 train_accuracy:  72.673 validate_loss:  1.652 validate_accuracy:  56.383
epoch: 20 train_loss:  0.937 train_accuracy:  71.712 validate_loss:  1.646 validate_accuracy:  56.068
epoch: 21 train_loss:  0.943 train_accuracy:  72.343 validate_loss:  1.673 validate_accuracy:  55.402
epoch: 22 train_loss:  0.944 train_accuracy:  71.739 validate_loss:  1.736 validate_accuracy:  56.044
epoch: 23 train_loss:  0.945 train_accuracy:  71.869 validate_loss:  1.681 validate_accuracy:  55.414
epoch: 24 train_loss:  0.911 train_accuracy:  72.698 validate_loss:  1.609 validate_accuracy:  57.607
epoch: 25 train_loss:  0.949 train_accuracy:  71.505 validate_loss:  1.745 validate_accuracy:  54.082
epoch: 26 train_loss:  0.931 train_accuracy:  71.992 validate_loss:  1.728 validate_accuracy:  54.566
epoch: 27 train_loss:  0.886 train_accuracy:  73.552 validate_loss:  1.671 validate_accuracy:  55.705
epoch: 28 train_loss:  0.939 train_accuracy:  72.467 validate_loss:  1.692 validate_accuracy:  56.214
epoch: 29 train_loss:  0.900 train_accuracy:  72.840 validate_loss:  1.671 validate_accuracy:  54.797
epoch: 30 train_loss:  0.908 train_accuracy:  73.018 validate_loss:  1.673 validate_accuracy:  55.366
epoch: 31 train_loss:  0.914 train_accuracy:  73.277 validate_loss:  1.584 validate_accuracy:  56.844
epoch: 32 train_loss:  0.923 train_accuracy:  72.020 validate_loss:  1.683 validate_accuracy:  55.584
epoch: 33 train_loss:  0.885 train_accuracy:  73.444 validate_loss:  1.708 validate_accuracy:  56.407
epoch: 34 train_loss:  0.870 train_accuracy:  73.632 validate_loss:  1.626 validate_accuracy:  56.202
epoch: 35 train_loss:  0.897 train_accuracy:  73.391 validate_loss:  1.676 validate_accuracy:  55.208
epoch: 36 train_loss:  0.915 train_accuracy:  73.074 validate_loss:  1.705 validate_accuracy:  53.949
epoch: 37 train_loss:  0.931 train_accuracy:  72.759 validate_loss:  1.755 validate_accuracy:  54.542
epoch: 38 train_loss:  0.893 train_accuracy:  73.422 validate_loss:  1.773 validate_accuracy:  56.456
epoch: 39 train_loss:  0.895 train_accuracy:  73.083 validate_loss:  1.697 validate_accuracy:  56.202
epoch: 40 train_loss:  0.893 train_accuracy:  72.926 validate_loss:  1.743 validate_accuracy:  56.056
epoch: 41 train_loss:  0.856 train_accuracy:  73.980 validate_loss:  1.738 validate_accuracy:  55.426
epoch: 42 train_loss:  0.881 train_accuracy:  73.968 validate_loss:  1.749 validate_accuracy:  55.233
epoch: 43 train_loss:  0.906 train_accuracy:  73.213 validate_loss:  1.688 validate_accuracy:  55.935
epoch: 44 train_loss:  0.906 train_accuracy:  73.323 validate_loss:  1.706 validate_accuracy:  56.177
epoch: 45 train_loss:  0.911 train_accuracy:  73.240 validate_loss:  1.714 validate_accuracy:  55.402
epoch: 46 train_loss:  0.921 train_accuracy:  72.673 validate_loss:  1.751 validate_accuracy:  53.634
epoch: 47 train_loss:  0.854 train_accuracy:  74.353 validate_loss:  1.711 validate_accuracy:  57.025
epoch: 48 train_loss:  0.919 train_accuracy:  73.077 validate_loss:  1.722 validate_accuracy:  54.966
epoch: 49 train_loss:  0.897 train_accuracy:  73.345 validate_loss:  1.759 validate_accuracy:  55.390
epoch: 50 train_loss:  0.887 train_accuracy:  73.780 validate_loss:  1.702 validate_accuracy:  56.444
epoch: 51 train_loss:  0.864 train_accuracy:  74.270 validate_loss:  1.694 validate_accuracy:  56.940
epoch: 52 train_loss:  0.863 train_accuracy:  74.458 validate_loss:  1.749 validate_accuracy:  54.203
epoch: 53 train_loss:  0.879 train_accuracy:  74.060 validate_loss:  1.749 validate_accuracy:  53.973
epoch: 54 train_loss:  0.866 train_accuracy:  73.937 validate_loss:  1.733 validate_accuracy:  54.336
epoch: 55 train_loss:  0.897 train_accuracy:  73.555 validate_loss:  1.700 validate_accuracy:  55.051
epoch: 56 train_loss:  0.920 train_accuracy:  72.602 validate_loss:  1.690 validate_accuracy:  55.463
epoch: 57 train_loss:  0.850 train_accuracy:  75.099 validate_loss:  1.699 validate_accuracy:  56.953
epoch: 58 train_loss:  0.867 train_accuracy:  74.011 validate_loss:  1.647 validate_accuracy:  55.959
epoch: 59 train_loss:  0.839 train_accuracy:  75.398 validate_loss:  1.748 validate_accuracy:  56.989
epoch: 60 train_loss:  0.854 train_accuracy:  74.726 validate_loss:  1.865 validate_accuracy:  55.632
epoch: 61 train_loss:  0.899 train_accuracy:  74.005 validate_loss:  1.741 validate_accuracy:  56.008
epoch: 62 train_loss:  0.846 train_accuracy:  74.661 validate_loss:  1.751 validate_accuracy:  57.437
epoch: 63 train_loss:  0.884 train_accuracy:  74.140 validate_loss:  1.741 validate_accuracy:  55.402
epoch: 64 train_loss:  0.822 train_accuracy:  75.515 validate_loss:  1.789 validate_accuracy:  55.838
epoch: 65 train_loss:  0.872 train_accuracy:  74.670 validate_loss:  1.762 validate_accuracy:  54.494
epoch: 66 train_loss:  0.908 train_accuracy:  73.576 validate_loss:  1.690 validate_accuracy:  55.620
epoch: 67 train_loss:  0.861 train_accuracy:  74.686 validate_loss:  1.708 validate_accuracy:  56.189
epoch: 68 train_loss:  0.861 train_accuracy:  74.593 validate_loss:  1.739 validate_accuracy:  55.826
epoch: 69 train_loss:  0.884 train_accuracy:  74.057 validate_loss:  1.721 validate_accuracy:  54.845
epoch: 70 train_loss:  0.865 train_accuracy:  74.775 validate_loss:  1.813 validate_accuracy:  54.760
epoch: 71 train_loss:  0.896 train_accuracy:  73.755 validate_loss:  1.762 validate_accuracy:  53.755
epoch: 72 train_loss:  0.832 train_accuracy:  74.951 validate_loss:  1.748 validate_accuracy:  55.196
epoch: 73 train_loss:  0.842 train_accuracy:  75.034 validate_loss:  1.670 validate_accuracy:  55.947
epoch: 74 train_loss:  0.889 train_accuracy:  73.986 validate_loss:  1.761 validate_accuracy:  55.911
epoch: 75 train_loss:  0.861 train_accuracy:  74.596 validate_loss:  1.804 validate_accuracy:  54.881
epoch: 76 train_loss:  0.862 train_accuracy:  75.148 validate_loss:  1.761 validate_accuracy:  54.142
epoch: 77 train_loss:  0.833 train_accuracy:  75.693 validate_loss:  1.736 validate_accuracy:  55.329
epoch: 78 train_loss:  0.887 train_accuracy:  74.063 validate_loss:  1.694 validate_accuracy:  55.305
epoch: 79 train_loss:  0.822 train_accuracy:  75.324 validate_loss:  1.733 validate_accuracy:  56.274
epoch: 80 train_loss:  0.828 train_accuracy:  75.884 validate_loss:  1.680 validate_accuracy:  56.177
epoch: 81 train_loss:  0.839 train_accuracy:  75.173 validate_loss:  1.690 validate_accuracy:  55.947
epoch: 82 train_loss:  0.860 train_accuracy:  74.609 validate_loss:  1.784 validate_accuracy:  55.463
epoch: 83 train_loss:  0.846 train_accuracy:  75.247 validate_loss:  1.783 validate_accuracy:  57.122
epoch: 84 train_loss:  0.852 train_accuracy:  75.083 validate_loss:  1.808 validate_accuracy:  55.669
epoch: 85 train_loss:  0.845 train_accuracy:  75.339 validate_loss:  1.755 validate_accuracy:  56.347
epoch: 86 train_loss:  0.832 train_accuracy:  75.592 validate_loss:  1.766 validate_accuracy:  54.784
epoch: 87 train_loss:  0.850 train_accuracy:  75.592 validate_loss:  1.813 validate_accuracy:  53.791
epoch: 88 train_loss:  0.872 train_accuracy:  74.528 validate_loss:  1.943 validate_accuracy:  53.719
epoch: 89 train_loss:  0.848 train_accuracy:  75.151 validate_loss:  1.756 validate_accuracy:  56.904
epoch: 90 train_loss:  0.914 train_accuracy:  73.567 validate_loss:  1.770 validate_accuracy:  54.700
epoch: 91 train_loss:  0.877 train_accuracy:  74.491 validate_loss:  1.850 validate_accuracy:  53.852
epoch: 92 train_loss:  0.883 train_accuracy:  74.935 validate_loss:  1.777 validate_accuracy:  55.814
epoch: 93 train_loss:  0.846 train_accuracy:  75.542 validate_loss:  1.776 validate_accuracy:  56.202
epoch: 94 train_loss:  0.866 train_accuracy:  74.732 validate_loss:  1.743 validate_accuracy:  56.141
epoch: 95 train_loss:  0.823 train_accuracy:  75.906 validate_loss:  1.854 validate_accuracy:  55.596
epoch: 96 train_loss:  0.842 train_accuracy:  75.413 validate_loss:  1.725 validate_accuracy:  55.281
epoch: 97 train_loss:  0.839 train_accuracy:  75.881 validate_loss:  1.809 validate_accuracy:  55.632
epoch: 98 train_loss:  0.832 train_accuracy:  75.977 validate_loss:  1.862 validate_accuracy:  54.227
epoch: 99 train_loss:  0.835 train_accuracy:  75.675 validate_loss:  1.756 validate_accuracy:  55.245
epoch: 100 train_loss:  0.844 train_accuracy:  75.370 validate_loss:  1.864 validate_accuracy:  54.554
In [18]:
class CNN_compared(nn.Module):
    def __init__(self):
        super(CNN_compared,self).__init__()
        # Three conv blocks (3 -> 16 -> 32 -> 64 channels), each followed by 2x2 max-pooling
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # With 224x224 inputs, three halvings leave 28x28 feature maps
        self.flc1 = nn.Linear(64*28*28,1024)
        self.dropout = nn.Dropout(p=0.3)
        self.flc2 = nn.Linear(1024,30)

    def forward(self,x):
        x = self.maxpool1(nn.functional.relu(self.conv1(x)))
        x = self.maxpool2(nn.functional.relu(self.conv2(x)))
        x = self.maxpool3(nn.functional.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*28*28)
        x = self.dropout(x)
        x = nn.functional.relu(self.flc1(x))
        x = self.flc2(x)  # raw logits for the 30 classes
        return x
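The `64*28*28` input size of `flc1` follows from the pooling arithmetic: each of the three 2x2 max-pools halves the spatial dimensions, so a 224x224 input (the resize used elsewhere in this notebook; an assumption here) ends up as a 28x28 map with 64 channels. A quick sanity check:

```python
# Each MaxPool2d(kernel_size=2, stride=2) halves the spatial size.
size = 224                 # assumed input resolution
for _ in range(3):         # three conv + pool blocks
    size //= 2
print(size)                # 28
print(64 * size * size)    # 50176 == 64*28*28, the in_features of flc1
```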
In [29]:
CNN_compare_model = CNN_compared()
CNN_compare_model = CNN_compare_model.to(device)
# Your code here! 
cnn_start = time.time()
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_compare_model.parameters(), 0.001)

CNN_compare_best_loss = 1000.0
CNN_compare_train_loss, CNN_compare_train_accuracy= [], []
CNN_compare_validate_loss, CNN_compare_validate_accuracy= [], []

for epoch in range(nepochs):
    CNN_compare_train_running_loss , CNN_compare_train_running_accuracy = train(train_loader, CNN_compare_model, criterion, optimizer)
    CNN_compare_train_loss.append(CNN_compare_train_running_loss)
    CNN_compare_train_accuracy.append(CNN_compare_train_running_accuracy)
    CNN_compare_validate_running_loss , CNN_compare_validate_running_accuracy = validate(validate_loader, CNN_compare_model, criterion, optimizer)
    CNN_compare_validate_loss.append(CNN_compare_validate_running_loss)
    CNN_compare_validate_accuracy.append(CNN_compare_validate_running_accuracy)
    if CNN_compare_validate_running_loss < CNN_compare_best_loss:
        CNN_compare_best_loss = CNN_compare_validate_running_loss
        torch.save(CNN_compare_model.state_dict(), './cnn_compare.pt')
    print(f"epoch: {epoch+1} train_loss: {CNN_compare_train_running_loss : .3f} train_accuracy: {CNN_compare_train_running_accuracy : .3f} validate_loss: {CNN_compare_validate_running_loss : .3f} validate_accuracy: {CNN_compare_validate_running_accuracy : .3f}")
cnn_end = time.time()
cnn_running_time = cnn_end - cnn_start
epoch: 1 train_loss:  3.293 train_accuracy:  9.159 validate_loss:  3.021 validate_accuracy:  13.118
epoch: 2 train_loss:  2.967 train_accuracy:  15.178 validate_loss:  2.878 validate_accuracy:  17.696
epoch: 3 train_loss:  2.739 train_accuracy:  20.707 validate_loss:  2.666 validate_accuracy:  23.462
epoch: 4 train_loss:  2.599 train_accuracy:  24.797 validate_loss:  2.595 validate_accuracy:  25.230
epoch: 5 train_loss:  2.511 train_accuracy:  27.197 validate_loss:  2.544 validate_accuracy:  26.441
epoch: 6 train_loss:  2.420 train_accuracy:  29.370 validate_loss:  2.435 validate_accuracy:  30.245
epoch: 7 train_loss:  2.354 train_accuracy:  31.161 validate_loss:  2.444 validate_accuracy:  30.015
epoch: 8 train_loss:  2.284 train_accuracy:  33.343 validate_loss:  2.338 validate_accuracy:  32.401
epoch: 9 train_loss:  2.203 train_accuracy:  35.392 validate_loss:  2.336 validate_accuracy:  33.321
epoch: 10 train_loss:  2.126 train_accuracy:  37.133 validate_loss:  2.305 validate_accuracy:  34.096
epoch: 11 train_loss:  2.056 train_accuracy:  39.614 validate_loss:  2.316 validate_accuracy:  33.188
epoch: 12 train_loss:  2.004 train_accuracy:  41.306 validate_loss:  2.270 validate_accuracy:  34.133
epoch: 13 train_loss:  1.938 train_accuracy:  42.490 validate_loss:  2.270 validate_accuracy:  34.823
epoch: 14 train_loss:  1.887 train_accuracy:  44.181 validate_loss:  2.208 validate_accuracy:  36.955
epoch: 15 train_loss:  1.823 train_accuracy:  45.750 validate_loss:  2.248 validate_accuracy:  36.313
epoch: 16 train_loss:  1.773 train_accuracy:  46.641 validate_loss:  2.263 validate_accuracy:  37.621
epoch: 17 train_loss:  1.735 train_accuracy:  48.724 validate_loss:  2.278 validate_accuracy:  36.470
epoch: 18 train_loss:  1.671 train_accuracy:  50.120 validate_loss:  2.203 validate_accuracy:  37.984
epoch: 19 train_loss:  1.615 train_accuracy:  51.624 validate_loss:  2.309 validate_accuracy:  37.052
epoch: 20 train_loss:  1.557 train_accuracy:  53.985 validate_loss:  2.216 validate_accuracy:  38.748
epoch: 21 train_loss:  1.515 train_accuracy:  54.367 validate_loss:  2.300 validate_accuracy:  37.645
epoch: 22 train_loss:  1.469 train_accuracy:  55.711 validate_loss:  2.257 validate_accuracy:  39.087
epoch: 23 train_loss:  1.414 train_accuracy:  57.128 validate_loss:  2.335 validate_accuracy:  36.906
epoch: 24 train_loss:  1.368 train_accuracy:  58.728 validate_loss:  2.313 validate_accuracy:  38.033
epoch: 25 train_loss:  1.324 train_accuracy:  59.970 validate_loss:  2.422 validate_accuracy:  38.154
epoch: 26 train_loss:  1.285 train_accuracy:  60.793 validate_loss:  2.357 validate_accuracy:  38.299
epoch: 27 train_loss:  1.245 train_accuracy:  62.432 validate_loss:  2.369 validate_accuracy:  38.578
epoch: 28 train_loss:  1.222 train_accuracy:  62.913 validate_loss:  2.409 validate_accuracy:  37.330
epoch: 29 train_loss:  1.147 train_accuracy:  64.925 validate_loss:  2.483 validate_accuracy:  38.009
epoch: 30 train_loss:  1.137 train_accuracy:  65.243 validate_loss:  2.422 validate_accuracy:  40.007
epoch: 31 train_loss:  1.112 train_accuracy:  66.355 validate_loss:  2.556 validate_accuracy:  36.701
epoch: 32 train_loss:  1.075 train_accuracy:  67.172 validate_loss:  2.462 validate_accuracy:  37.028
epoch: 33 train_loss:  1.050 train_accuracy:  68.371 validate_loss:  2.510 validate_accuracy:  37.391
epoch: 34 train_loss:  0.990 train_accuracy:  69.576 validate_loss:  2.611 validate_accuracy:  38.178
epoch: 35 train_loss:  0.981 train_accuracy:  70.081 validate_loss:  2.607 validate_accuracy:  38.505
epoch: 36 train_loss:  0.933 train_accuracy:  71.182 validate_loss:  2.665 validate_accuracy:  38.832
epoch: 37 train_loss:  0.915 train_accuracy:  72.165 validate_loss:  2.571 validate_accuracy:  38.639
epoch: 38 train_loss:  0.907 train_accuracy:  72.809 validate_loss:  2.680 validate_accuracy:  38.675
epoch: 39 train_loss:  0.877 train_accuracy:  72.602 validate_loss:  2.735 validate_accuracy:  36.967
epoch: 40 train_loss:  0.850 train_accuracy:  74.082 validate_loss:  2.639 validate_accuracy:  37.924
epoch: 41 train_loss:  0.826 train_accuracy:  74.328 validate_loss:  2.842 validate_accuracy:  37.609
epoch: 42 train_loss:  0.820 train_accuracy:  74.784 validate_loss:  2.658 validate_accuracy:  39.050
epoch: 43 train_loss:  0.782 train_accuracy:  76.193 validate_loss:  2.801 validate_accuracy:  38.651
epoch: 44 train_loss:  0.766 train_accuracy:  76.498 validate_loss:  2.822 validate_accuracy:  38.869
epoch: 45 train_loss:  0.764 train_accuracy:  76.436 validate_loss:  2.869 validate_accuracy:  37.149
epoch: 46 train_loss:  0.728 train_accuracy:  77.632 validate_loss:  2.796 validate_accuracy:  37.827
epoch: 47 train_loss:  0.726 train_accuracy:  77.962 validate_loss:  2.849 validate_accuracy:  38.978
epoch: 48 train_loss:  0.717 train_accuracy:  77.835 validate_loss:  2.976 validate_accuracy:  37.209
epoch: 49 train_loss:  0.696 train_accuracy:  78.359 validate_loss:  2.927 validate_accuracy:  39.293
epoch: 50 train_loss:  0.654 train_accuracy:  79.946 validate_loss:  3.114 validate_accuracy:  37.088
epoch: 51 train_loss:  0.655 train_accuracy:  79.617 validate_loss:  2.909 validate_accuracy:  37.536
epoch: 52 train_loss:  0.660 train_accuracy:  79.706 validate_loss:  2.931 validate_accuracy:  38.857
epoch: 53 train_loss:  0.620 train_accuracy:  80.464 validate_loss:  3.119 validate_accuracy:  37.573
epoch: 54 train_loss:  0.597 train_accuracy:  81.561 validate_loss:  3.084 validate_accuracy:  38.094
epoch: 55 train_loss:  0.589 train_accuracy:  81.950 validate_loss:  3.116 validate_accuracy:  38.699
epoch: 56 train_loss:  0.612 train_accuracy:  80.886 validate_loss:  3.231 validate_accuracy:  37.754
epoch: 57 train_loss:  0.573 train_accuracy:  82.215 validate_loss:  3.139 validate_accuracy:  38.190
epoch: 58 train_loss:  0.577 train_accuracy:  81.740 validate_loss:  3.138 validate_accuracy:  37.875
epoch: 59 train_loss:  0.574 train_accuracy:  82.261 validate_loss:  3.119 validate_accuracy:  38.142
epoch: 60 train_loss:  0.556 train_accuracy:  82.640 validate_loss:  3.278 validate_accuracy:  37.779
epoch: 61 train_loss:  0.544 train_accuracy:  82.683 validate_loss:  3.191 validate_accuracy:  38.614
epoch: 62 train_loss:  0.534 train_accuracy:  83.802 validate_loss:  3.258 validate_accuracy:  37.548
epoch: 63 train_loss:  0.519 train_accuracy:  83.645 validate_loss:  3.264 validate_accuracy:  38.154
epoch: 64 train_loss:  0.522 train_accuracy:  83.460 validate_loss:  3.243 validate_accuracy:  38.069
epoch: 65 train_loss:  0.489 train_accuracy:  84.751 validate_loss:  3.396 validate_accuracy:  37.064
epoch: 66 train_loss:  0.500 train_accuracy:  84.600 validate_loss:  3.404 validate_accuracy:  38.953
epoch: 67 train_loss:  0.495 train_accuracy:  84.477 validate_loss:  3.234 validate_accuracy:  39.268
epoch: 68 train_loss:  0.502 train_accuracy:  84.532 validate_loss:  3.371 validate_accuracy:  38.009
epoch: 69 train_loss:  0.468 train_accuracy:  85.660 validate_loss:  3.432 validate_accuracy:  36.652
epoch: 70 train_loss:  0.476 train_accuracy:  85.306 validate_loss:  3.411 validate_accuracy:  37.258
epoch: 71 train_loss:  0.493 train_accuracy:  84.850 validate_loss:  3.443 validate_accuracy:  36.761
epoch: 72 train_loss:  0.467 train_accuracy:  85.370 validate_loss:  3.412 validate_accuracy:  38.227
epoch: 73 train_loss:  0.467 train_accuracy:  84.998 validate_loss:  3.360 validate_accuracy:  37.948
epoch: 74 train_loss:  0.462 train_accuracy:  85.404 validate_loss:  3.426 validate_accuracy:  38.057
epoch: 75 train_loss:  0.429 train_accuracy:  86.418 validate_loss:  3.613 validate_accuracy:  38.178
epoch: 76 train_loss:  0.439 train_accuracy:  86.138 validate_loss:  3.534 validate_accuracy:  36.834
epoch: 77 train_loss:  0.445 train_accuracy:  86.378 validate_loss:  3.481 validate_accuracy:  38.227
epoch: 78 train_loss:  0.444 train_accuracy:  86.061 validate_loss:  3.563 validate_accuracy:  37.016
epoch: 79 train_loss:  0.400 train_accuracy:  87.016 validate_loss:  3.649 validate_accuracy:  36.240
epoch: 80 train_loss:  0.413 train_accuracy:  87.186 validate_loss:  3.695 validate_accuracy:  37.452
epoch: 81 train_loss:  0.415 train_accuracy:  87.099 validate_loss:  3.598 validate_accuracy:  38.081
epoch: 82 train_loss:  0.407 train_accuracy:  87.737 validate_loss:  3.718 validate_accuracy:  37.536
epoch: 83 train_loss:  0.392 train_accuracy:  87.913 validate_loss:  3.673 validate_accuracy:  37.670
epoch: 84 train_loss:  0.392 train_accuracy:  87.861 validate_loss:  3.589 validate_accuracy:  39.147
epoch: 85 train_loss:  0.411 train_accuracy:  87.272 validate_loss:  3.622 validate_accuracy:  37.597
epoch: 86 train_loss:  0.392 train_accuracy:  87.777 validate_loss:  3.716 validate_accuracy:  37.173
epoch: 87 train_loss:  0.391 train_accuracy:  87.706 validate_loss:  3.713 validate_accuracy:  38.542
epoch: 88 train_loss:  0.366 train_accuracy:  88.240 validate_loss:  3.733 validate_accuracy:  38.469
epoch: 89 train_loss:  0.380 train_accuracy:  87.972 validate_loss:  3.832 validate_accuracy:  35.744
epoch: 90 train_loss:  0.387 train_accuracy:  88.332 validate_loss:  3.761 validate_accuracy:  38.348
epoch: 91 train_loss:  0.382 train_accuracy:  88.107 validate_loss:  3.777 validate_accuracy:  38.542
epoch: 92 train_loss:  0.352 train_accuracy:  88.671 validate_loss:  3.834 validate_accuracy:  37.270
epoch: 93 train_loss:  0.359 train_accuracy:  88.911 validate_loss:  3.713 validate_accuracy:  37.972
epoch: 94 train_loss:  0.375 train_accuracy:  88.613 validate_loss:  3.731 validate_accuracy:  38.106
epoch: 95 train_loss:  0.339 train_accuracy:  89.294 validate_loss:  3.757 validate_accuracy:  38.578
epoch: 96 train_loss:  0.364 train_accuracy:  88.650 validate_loss:  3.875 validate_accuracy:  37.718
epoch: 97 train_loss:  0.345 train_accuracy:  88.995 validate_loss:  3.580 validate_accuracy:  39.014
epoch: 98 train_loss:  0.334 train_accuracy:  89.586 validate_loss:  3.875 validate_accuracy:  37.561
epoch: 99 train_loss:  0.346 train_accuracy:  88.985 validate_loss:  3.789 validate_accuracy:  38.493
epoch: 100 train_loss:  0.338 train_accuracy:  89.534 validate_loss:  3.988 validate_accuracy:  37.779

5.2 Compare results on the validation set of TinyImageNet30¶

Loss graph, top-1 accuracy, confusion matrix and execution time for your model (say, mymodel) and AlexNet
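Top-1 accuracy can be read straight off a confusion matrix, since correct predictions lie on the diagonal. A minimal sketch with a made-up 3-class matrix (not real results from this notebook):

```python
import numpy as np

# Hypothetical confusion matrix: rows = true class, columns = predicted class.
cm = np.array([[8, 1, 1],
               [2, 7, 1],
               [0, 3, 7]])

top1 = np.trace(cm) / cm.sum()   # diagonal entries are the correct predictions
print(top1)                      # 22/30 ~= 0.733
```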

In [22]:
# Your code here! 
# Loss and accuracy graph 
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(1,2,figsize=(15,20),sharex=False,sharey=False)
fig.suptitle('Comparing CNN and AlexNet on the validation set')

axs[0].plot(x_axis,alexnet_compare_validate_loss,label='AlexNet validate_loss')
axs[0].plot(x_axis,CNN_compare_validate_loss,label='CNN validate_loss')
axs[1].plot(x_axis,alexnet_compare_validate_accuracy,label='AlexNet validate_accuracy')
axs[1].plot(x_axis, CNN_compare_validate_accuracy,label='CNN validate_accuracy')

axs[0].set_xlabel('epoch')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()

plt.show()
In [26]:
# confusion matrix 
nclasses = len(classes)
CNN_compare_model.load_state_dict(torch.load('./cnn_compare.pt'))
alexnet_compare.load_state_dict(torch.load('./alexnet_compare.pt'))
CNN_compare_model.eval()   # disable dropout for inference
alexnet_compare.eval()

# rows: true class, columns: predicted class
cnfm_cnn = np.zeros((nclasses,nclasses),dtype=int)
cnfm_alexnet = np.zeros((nclasses,nclasses),dtype=int)

with torch.no_grad():
    for data in validate_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)

        CNN_outputs = CNN_compare_model(images)
        _, CNN_predicted = torch.max(CNN_outputs, 1)
        for i in range(labels.size(0)):
            cnfm_cnn[labels[i].item(),CNN_predicted[i].item()] += 1

        alexnet_outputs = alexnet_compare(images)
        _, alexnet_predicted = torch.max(alexnet_outputs, 1)
        for i in range(labels.size(0)):
            cnfm_alexnet[labels[i].item(),alexnet_predicted[i].item()] += 1
        
print("CNN Model Confusion Matrix")
print(cnfm_cnn)

# show confusion matrix as a grey-level image
plt.imshow(cnfm_cnn, cmap='gray')
CNN Model Confusion Matrix
[[28  0  0  5  1  1  2  4  7  3  2  1  3  2  0  0  1  0  7  1  5  0  1 12
   0  0  2  0  2  0]
 [ 1 22  0  0  1  0  0  1  0  2  1 11  0  0  4 10  0  1  0  5  0  0  2  1
   0  7  0  0  1  1]
 [ 0  1 25  0  6  0  1  3  1  0 29  1  0  2  1  2  2  0  0  6  2  0  1  0
   1  1  2  1  2  1]
 [ 4  0  0 41  0  0  0  2  9  0  0  0 10  0  0  0  2  2  4  0  2  1  0 10
   0  0  3  1  0  1]
 [ 0  1  2  0 84  0  0  0  0  0  2  0  0  2  0  0  0  1  0  2  1  0  1  0
   1  0  3  1  3  0]
 [ 0  6  6  0  0 18  3  4  1  3  0  6  0  1  6  3  2  4  0  4  1  1  2  2
   1  3  2  2  5  3]
 [ 4  0  2  2  0  1 32  1  2  1  1  2  2  2  6  0  1  0  5  1  5  2  3  2
   1  0  5  4  5  0]
 [ 1  2  1  1  0  2  0 21  1  2  1  0  2  1  7  8  3  1  4  3  3  0  2  7
   0  2  5  2  9  0]
 [ 9  0  1 13  0  0  0  3 20  1  3  0  8  1  0  1  1  0 10  0  3  2  2  2
   0  0  1  3  4  1]
 [ 0  2  1  0  1  1  0  2  0 46  2  2  0  2  2  9  4  0  0  1  0  0  0  0
   2  2  1  2  4  1]
 [ 3  1  7  3  3  1  0  1  2  0 49  1  2 10  0  2  1  2  1  3  1  0  0  0
   0  0  0  0  1  3]
 [ 0  9  0  0  1  1  0  2  0  1  0 48  1  0  1  5  1  1  0  7  4  1  0  0
   1  2  1  0  3  2]
 [ 2  1  1  5  0  0  1  4  1  0  5  0 37  1  0  0  2  1  3  0  3  2  3  9
   0  0  2  2  2  1]
 [ 0  2 10  3  6  0  1  2  0  0 14  2  5 22  0  3  1  1  4  6  2  0  1  1
   4  0  1  1  1  1]
 [ 0  2  2  0  0  4  5  4  1  1  4  2  1  0 35  0  8  0  1  1  1  0  6  1
   1  2  0  3  2  3]
 [ 1  0  4  1  0  1  0  2  1  6  2  1  1  3  4 28  2  1  2  3  3  1  0  4
   5  9  2  3  3  1]
 [ 0  0  0  3  0  0  0  3  0  5  1  1  2  0  6  4 24  1  3  2  0  0  2  0
   0  0  2  6  5  2]
 [ 0  1  1  0  0  2  0  1  0  1  1  1  1  0  2  2  0 66  0  0  1  0  0  1
   0  0  0  1  1  0]
 [ 5  0  3  0  0  0  2  3  1  0  0  0  4  0  0  0  0  0 59  0  1  2  0  3
   0  1  6  1  1  1]
 [ 1  6  4  0  3  1  2  1  0  3  7  4  4  3  4  6  0  0  0 37  0  0  1  1
   0  1  3  1  1  4]
 [14  2  3  2  1  1  1  2  1  0  0  0  5  1  2  3  0  0  5  0 31  5  0  5
   4  0  3  2  2  0]
 [ 2  1  2  2  2  1  0  1  0  3  4  7  1  2  3  2  0  4  3  1  4 17  0  1
   3  1  9  1  5  0]
 [ 0  0  0  5  0  1  1  4  4  0  1  0  7  2  5  0  3  1  1  2  0  0 32  0
   0  0  3  4  3  0]
 [ 6  0  0  6  1  0  1  0  6  1  4  0  9  5  1  5  1  0  7  0 10  1  1 12
   0  1  6  4  3  2]
 [ 0  6  2  0  0  1  0  0  1  1  2  0  0  2  0  8  0  0  0  0  1  1  0  0
  51  1  2  1  7  0]
 [ 0  4  0  1  0  1  1  3  1  3  1  2  0  4  2 16  0  0  0  2  3  0  0  0
  10 23  1  2  4  2]
 [ 5  0  2  3  2  1  0  2  0  3  0  0  1  1  1  3  0  0  6  2  1  3  1  2
   0  0 55  0  0  0]
 [ 1  0  1  0  0  4  2  3  2  3  0  0  3  5  7  4  6  0  3  1  0  0  3  4
   1  0  0 34  2  2]
 [ 1  1  4  4  2  1  2  3  1  0  4  0  0  1  8  6  3  2  2  1  0  4  2  6
   2  3  3  0 30  4]
 [ 1  1  3  2  4  1  3  4  2  1  3  1  4  2  5  1  4  3  4  1  0  4  2  4
   2  1  3  4 10 16]]
Out[26]:
<matplotlib.image.AxesImage at 0x1a7662d9550>
In [25]:
print("AlexNet Model Confusion Matrix")
print(cnfm_alexnet)

# show confusion matrix as a grey-level image
plt.imshow(cnfm_alexnet, cmap='gray')
AlexNet Model Confusion Matrix
[[42  0  0  5  0  1  2  3  4  0  1  0  2  0  0  0  1  0 10  0  7  0  0  8
   0  0  1  0  0  3]
 [ 0 28  0  0  1  4  1  0  0  0  3  7  0  1  5  3  1  0  2  0  0  2  2  1
   0  5  0  0  2  3]
 [ 0  0 58  1  1  0  1  4  0  0  9  1  0  3  1  0  2  0  2  0  0  2  0  1
   0  0  2  0  0  3]
 [ 2  0  0 71  0  0  1  1  2  0  0  0  1  0  0  0  0  1  0  0  2  1  0  6
   0  0  1  0  0  3]
 [ 0  1  5  0 93  0  0  0  0  0  0  0  0  1  0  0  0  0  1  0  0  0  0  0
   0  0  1  0  0  2]
 [ 1  5  1  0  0 37  2  3  0  4  1  3  0  0  5  4  0  3  2  1  1  2  0  1
   1  2  0  6  2  2]
 [ 0  3  0  2  0  1 60  2  1  0  0  2  0  1  2  0  0  3  5  0  1  1  2  1
   0  0  0  0  1  4]
 [ 2  0  1  0  0  3  1 51  2  0  0  2  4  1  2  1  1  1  3  0  2  0  1  7
   0  0  2  0  3  1]
 [ 6  0  0  7  0  0  3  0 56  1  0  0  4  0  0  0  1  0  3  0  1  0  0  3
   0  0  2  0  0  2]
 [ 0  0  1  1  0  0  1  0  0 69  0  1  2  0  1  2  2  0  0  0  0  0  1  0
   3  1  1  0  1  0]
 [ 1  1 21  1  0  0  2  1  2  0 48  2  1  4  1  0  0  2  0  3  1  0  0  1
   0  2  1  0  1  1]
 [ 0  3  4  1  0  0  1  1  0  0  0 65  1  0  0  0  0  6  0  0  1  3  0  0
   0  0  4  0  0  2]
 [ 5  0  1  1  1  0  1  0  1  0  0  0 52  0  3  0  1  2  0  0  1  2  2  6
   0  0  1  1  2  5]
 [ 2  1  6  1  1  0  2  1  0  0 10  1  3 44  3  1  0  1  2  1  1  3  0  0
   0  0  4  1  1  4]
 [ 0  0  1  0  0  5  3  2  0  2  0  2  0  2 57  0  2  2  2  0  0  0  1  0
   0  0  1  3  2  3]
 [ 2  1  7  0  0  5  4  8  0  4  1  1  0  5  2 31  0  2  0  1  3  1  0  1
   2  4  0  1  5  3]
 [ 0  1  0  0  0  0  3  1  1  0  0  1  0  1  3  2 49  4  0  0  0  0  0  0
   0  0  0  2  1  3]
 [ 1  0  0  0  0  1  0  2  0  0  0  3  0  0  1  1  0 67  0  0  1  1  0  1
   0  0  0  1  0  3]
 [ 8  0  1  0  0  0  2  2  3  0  0  0  0  0  0  1  0  2 61  0  1  0  0  6
   0  0  4  0  1  1]
 [ 0  1 15  1  1  0  1  0  1  0  8  4  1  4  0  2  2  1  0 49  1  3  0  0
   0  0  3  0  0  0]
 [ 7  0  0  1  0  1  2  5  7  0  0  0  2  1  0  0  0  0  1  0 52  3  0  8
   0  2  2  0  0  1]
 [ 1  1  5  1  1  0  1  3  0  1  0  5  0  0  2  1  1  4  2  0  2 36  0  6
   1  1  5  0  2  0]
 [ 0  0  0  1  0  0  1  0  4  0  1  0  7  0  2  0  1  2  1  0  0  3 49  5
   0  0  0  1  0  1]
 [ 8  0  0 10  0  0  0  2  6  0  1  0  4  1  0  0  0  1  7  0  4  2  0 36
   0  4  4  0  1  2]
 [ 0  1  3  0  0  2  1  1  0  1  1  1  0  1  1  7  0  0  0  0  0  1  0  0
  56  5  0  1  3  1]
 [ 1  2  4  0  1  2  1  3  0  1  1  2  1  1  0  6  1  0  0  2  0  1  0  0
   4 41  1  1  6  3]
 [ 1  0  3  0  0  0  1  0  0  0  0  1  0  1  0  0  0  1  2  3  0  2  0  3
   0  0 75  0  0  1]
 [ 0  0  0  0  0  4  4  1  0  1  0  0  1  0  7  0  3  1  1  1  1  0  0  2
   0  1  0 60  3  0]
 [ 1  3  2  1  0  7  3  6  1  3  0  3  1  0  2  2  4  3  4  0  1  1  1  4
   2  2  0  0 37  6]
 [ 1  1  1  3  0  1  3  6  1  2  1  2  7  1  5  2  1  5  3  2  1  7  1  2
   1  1  0  3  3 29]]
Out[25]:
<matplotlib.image.AxesImage at 0x1a767883610>
In [30]:
# divmod(t, 60) splits seconds into (whole minutes, remaining seconds)
alexnet_minute, alexnet_second = divmod(alexnet_running_time, 60)
alexnet_hour, alexnet_minute = divmod(int(alexnet_minute), 60)

cnn_minute, cnn_second = divmod(cnn_running_time, 60)
cnn_hour, cnn_minute = divmod(int(cnn_minute), 60)

print("Training time of the AlexNet model on TinyImageNet30")
print(f"{alexnet_hour} hour {alexnet_minute} minute {alexnet_second:.3f} second")

print("Training time of the CNN model on TinyImageNet30")
print(f"{cnn_hour} hour {cnn_minute} minute {cnn_second:.3f} second")
Training time of the AlexNet model on TinyImageNet30
1 hour 44 minute 27.141 second
Training time of the CNN model on TinyImageNet30
1 hour 44 minute 27.370 second

6 Interpretation of results¶

Please use the TinyImageNet30 dataset for all results

6.1-6.2 Implement grad-CAM and visualise results¶

  • Use an existing library to initialise grad-CAM

      - To install: !pip install torchcam
      - Call SmoothGradCAMpp: from torchcam.methods import SmoothGradCAMpp
      - Apply to your model 

You can see the details here: https://github.com/frgfm/torch-cam

  • Apply grad-CAM to your model on at least four correctly classified images
  • Apply grad-CAM to the retrained AlexNet on at least four incorrectly classified images

It is recommended to first read the relevant paper Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, and refer to relevant course material.

HINT for displaying images with grad-CAM:

Display heatmap as a coloured heatmap superimposed onto the original image. We recommend the following steps to get a clear meaningful display.

Use from torchcam.utils import overlay_mask. Remember to resize your image, normalise it, and add a leading batch dimension of 1 (e.g., [1, 3, 224, 224])

In [182]:
# Your code here!
class cam_MyDataset(Dataset):
    def __init__(self, data_type, transform=train_transformer):
        '''
        data_type : ["train_set", "test_set"]
        '''
        root_path = "./comp5625M_data_assessment_1/"
        self.data_type = data_type
        data_root = pathlib.Path(root_path+self.data_type+"/"+self.data_type)

        if self.data_type == "train_set":
            all_image_paths = list(data_root.glob("*/*"))
            self.all_image_paths = all_image_paths
            self.all_image_labels = [int(classes[path.parent.name]) for path in all_image_paths]
            self.all_image_paths = [str(path) for path in all_image_paths]
            
            self.transform = transform

        else:
            all_image_paths = list(data_root.glob("*/"))
            self.all_image_paths = [str(path) for path in all_image_paths]
            self.all_image_labels = [str(path) for path in all_image_paths]
            
            self.transform = transform

    def __getitem__(self, index):
        img = cv.imread(self.all_image_paths[index])
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # OpenCV loads BGR; convert for correct colours
        org_img = torch.tensor(img)
        img = self.transform(img)
        label = self.all_image_labels[index]
        return img, label, org_img
    def __len__(self):
        return len(self.all_image_paths)
    

cam_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.Resize(224),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                    ])
cam_train_set = cam_MyDataset("train_set",transform=cam_transform)
cam_train_loader = DataLoader(
    cam_train_set,
    batch_size = 1,
    shuffle =True)
In [183]:
def get_cam_pic(model, train_iter, ifcorr):
    model.eval()
    cam_extractor = SmoothGradCAMpp(model)
    pics = []

    for data in train_iter:
        images, labels, org_img = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # extract the CAM for the predicted class
        activation_map = cam_extractor(outputs.squeeze(0).argmax().item(), outputs)[0]
        pre = torch.max(outputs.data, 1)[1].cpu().numpy()[0]
        true = labels.cpu().numpy()[0]
        if ifcorr:  # collect correctly classified images
            if true == pre:
                pics.append((images, org_img, activation_map))
                if len(pics) >= 4:  # stop after 4 images
                    break
        else:  # collect misclassified images
            if true != pre:
                pics.append((images, org_img, activation_map))
                if len(pics) >= 4:  # stop after 4 images
                    break

    return pics
                    
            
In [184]:
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import normalize, resize, to_pil_image

def plot_pic(p):
    _, img, activation_map = p
    img = img.cpu().numpy()[0, :, :, :]
    img = resize(to_pil_image(img), (224, 224))
    activation_map = activation_map.cpu().numpy()[0, :, :]
    result = overlay_mask(img, to_pil_image(activation_map, mode='F'), alpha=0.5)

    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(np.array(result))   # heatmap overlaid on the image
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(img)                # original image for comparison
    plt.axis('off')

    plt.tight_layout()
    plt.show()
In [185]:
cam_cnn = CNN_compared()
cam_cnn.load_state_dict(torch.load('cnn_compare.pt'))
cam_cnn = cam_cnn.to(device)

cnn_pics = get_cam_pic(cam_cnn, cam_train_loader, 1)  # correctly classified images
print(len(cnn_pics))
for p in cnn_pics:
    plot_pic(p)
WARNING:root:no value was provided for `target_layer`, thus set to 'maxpool3'.
4
In [186]:
cam_alexnet = models.alexnet(pretrained=False)
com_num_fc = cam_alexnet.classifier[6].in_features
cam_alexnet.classifier[6] = torch.nn.Linear(in_features=com_num_fc, out_features=30)
cam_alexnet.load_state_dict(torch.load('alexnet_compare.pt'))
cam_alexnet = cam_alexnet.to(device)

alexnet_pics = get_cam_pic(cam_alexnet, cam_train_loader, 0)  # misclassified images
print(len(alexnet_pics))
for p in alexnet_pics:
    plot_pic(p)
WARNING:root:no value was provided for `target_layer`, thus set to 'avgpool'.
4

6.3 Your comments on:¶

a) Why were model predictions correct or incorrect? You can support your case with the results from 6.2

Grad-CAM highlights which regions of an image contribute most to the model's classification decision. For correctly classified images, the model's attention tends to fall on the most distinctive and representative features of the object; this is especially true for images with simple scenes, a single subject and no occlusions. Incorrect classifications, by contrast, often occur when the model fails to localise the object properly, spreading its attention over background clutter and unrepresentative regions.

b) What can you do to improve your results further?

Collect more training data, deepen the network, apply data augmentation, train for more epochs, use transfer learning, and tune hyperparameters such as the learning rate, weight decay and dropout rate; batch normalisation could also be added.